mirror of
				https://github.com/nonebot/nonebot2.git
				synced 2025-11-03 16:36:44 +00:00 
			
		
		
		
	
		
			
				
	
	
		
			360 lines
		
	
	
		
			11 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			360 lines
		
	
	
		
			11 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
import abc
 | 
						|
import urllib.request
 | 
						|
from enum import Enum
 | 
						|
from dataclasses import dataclass
 | 
						|
from typing_extensions import TypeAlias
 | 
						|
from http.cookiejar import Cookie, CookieJar
 | 
						|
from typing import IO, Any, Union, Callable, Optional
 | 
						|
from collections.abc import Mapping, Iterator, Awaitable, MutableMapping
 | 
						|
 | 
						|
from yarl import URL as URL
 | 
						|
from multidict import CIMultiDict
 | 
						|
 | 
						|
# A pre-split URL: (scheme, host, port, path); every part is bytes except the
# optional integer port.
RawURL: TypeAlias = tuple[bytes, bytes, Optional[int], bytes]

# Query parameters: a single scalar value, or a list of scalar values.
SimpleQuery: TypeAlias = Union[str, int, float]
QueryVariable: TypeAlias = Union[SimpleQuery, list[SimpleQuery]]
QueryTypes: TypeAlias = Union[
    None,
    str,
    Mapping[str, QueryVariable],
    list[tuple[str, SimpleQuery]],
]

# Header containers accepted by Request / Response.
HeaderTypes: TypeAlias = Union[
    None,
    CIMultiDict[str],
    dict[str, str],
    list[tuple[str, str]],
]

# Cookie containers; "Cookies" is a forward reference to the class defined
# later in this module.
CookieTypes: TypeAlias = Union[
    None,
    "Cookies",
    CookieJar,
    dict[str, str],
    list[tuple[str, str]],
]

# Request / response body payloads.
ContentTypes: TypeAlias = Optional[Union[str, bytes]]
DataTypes: TypeAlias = Optional[dict]

# Upload files.
FileContent: TypeAlias = Union[IO[bytes], bytes]
# Fully expanded form: (filename, file or bytes, content_type).
FileType: TypeAlias = tuple[Optional[str], FileContent, Optional[str]]
FileTypes: TypeAlias = Union[
    FileContent,  # a bare file object or bytes
    tuple[Optional[str], FileContent],  # (filename, file or bytes)
    FileType,  # (filename, file or bytes, content_type)
]
FilesTypes: TypeAlias = Optional[
    Union[dict[str, FileTypes], list[tuple[str, FileTypes]]]
]


class HTTPVersion(Enum):
    """HTTP protocol versions.

    Each member's value is its version string, so ``HTTPVersion("1.1")``
    resolves to ``HTTPVersion.H11``.
    """

    H10 = "1.0"  # HTTP/1.0
    H11 = "1.1"  # HTTP/1.1
    H2 = "2"  # HTTP/2


class Request:
    """An HTTP request normalised into driver-independent fields.

    The method is upper-cased, the URL becomes a yarl ``URL`` (with ``params``
    merged into its query), headers become a ``CIMultiDict`` and cookies a
    ``Cookies`` container. Files are expanded to
    ``(field, (filename, content, content_type))`` tuples.
    """

    def __init__(
        self,
        method: Union[str, bytes],
        url: Union["URL", str, RawURL],
        *,
        params: QueryTypes = None,
        headers: HeaderTypes = None,
        cookies: CookieTypes = None,
        content: ContentTypes = None,
        data: DataTypes = None,
        json: Any = None,
        files: FilesTypes = None,
        version: Union[str, HTTPVersion] = HTTPVersion.H11,
        timeout: Optional[float] = None,
        proxy: Optional[str] = None,
    ):
        # Bytes methods are decoded as ASCII before upper-casing.
        method_text = method.decode("ascii") if isinstance(method, bytes) else method
        self.method: str = method_text.upper()

        self.version: HTTPVersion = HTTPVersion(version)
        self.timeout: Optional[float] = timeout
        self.proxy: Optional[str] = proxy

        self.url: URL = self._build_url(url, params)

        if headers is None:
            self.headers: CIMultiDict[str] = CIMultiDict()
        else:
            self.headers = CIMultiDict(headers)
        self.cookies = Cookies(cookies)

        # Body payloads are stored as given; only `files` is normalised.
        self.content: ContentTypes = content
        self.data: DataTypes = data
        self.json: Any = json
        self.files: Optional[list[tuple[str, FileType]]] = self._normalize_files(files)

    @staticmethod
    def _build_url(url: Union["URL", str, RawURL], params: QueryTypes) -> "URL":
        """Turn a URL / str / RawURL tuple into a URL, then merge ``params``."""
        if isinstance(url, tuple):
            raw_scheme, raw_host, port, raw_path = url
            built = URL.build(
                scheme=raw_scheme.decode("ascii"),
                host=raw_host.decode("ascii"),
                port=port,
                path=raw_path.decode("ascii"),
            )
        else:
            built = URL(url)
        return built if params is None else built.update_query(params)

    @staticmethod
    def _normalize_files(files: FilesTypes) -> Optional[list[tuple[str, FileType]]]:
        """Expand every file entry to ``(field, (filename, content, content_type))``.

        Returns None when ``files`` is None or empty. A bare file/bytes entry
        uses the field name as its filename.
        """
        if not files:
            return None
        entries = files.items() if isinstance(files, dict) else files
        expanded: list[tuple[str, FileType]] = []
        for field, info in entries:
            if not isinstance(info, tuple):
                expanded.append((field, (field, info, None)))
            elif len(info) == 2:
                expanded.append((field, (info[0], info[1], None)))
            else:
                expanded.append((field, info))  # type: ignore
        return expanded

    def __repr__(self) -> str:
        cls_name = type(self).__name__
        return f"{cls_name}(method={self.method!r}, url='{self.url!s}')"


class Response:
    """An HTTP response: status code, headers, optional body and its request."""

    def __init__(
        self,
        status_code: int,
        *,
        headers: HeaderTypes = None,
        content: ContentTypes = None,
        request: Optional[Request] = None,
    ):
        self.status_code: int = status_code
        # Headers are always normalised into a case-insensitive multidict.
        self.headers: CIMultiDict[str] = (
            CIMultiDict() if headers is None else CIMultiDict(headers)
        )
        self.content: ContentTypes = content
        # The request this response answers, if known.
        self.request: Optional[Request] = request

    def __repr__(self) -> str:
        cls_name = type(self).__name__
        return f"{cls_name}(status_code={self.status_code!r})"


class WebSocket(abc.ABC):
    """Abstract WebSocket connection implemented by concrete drivers.

    Subclasses provide the abstract connection primitives. :meth:`send`
    dispatches to :meth:`send_text` or :meth:`send_bytes` according to the
    payload type.
    """

    def __init__(self, *, request: Request):
        # The request associated with this WebSocket connection.
        self.request: Request = request

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}('{self.request.url!s}')"

    @property
    @abc.abstractmethod
    def closed(self) -> bool:
        """Whether the connection has already been closed."""
        raise NotImplementedError

    @abc.abstractmethod
    async def accept(self) -> None:
        """Accept the WebSocket connection request."""
        raise NotImplementedError

    @abc.abstractmethod
    async def close(self, code: int = 1000, reason: str = "") -> None:
        """Close the WebSocket connection with the given close code and reason."""
        raise NotImplementedError

    @abc.abstractmethod
    async def receive(self) -> Union[str, bytes]:
        """Receive one WebSocket message, either text (str) or binary (bytes)."""
        raise NotImplementedError

    @abc.abstractmethod
    async def receive_text(self) -> str:
        """Receive one WebSocket text message."""
        raise NotImplementedError

    @abc.abstractmethod
    async def receive_bytes(self) -> bytes:
        """Receive one WebSocket binary message."""
        raise NotImplementedError

    async def send(self, data: Union[str, bytes]) -> None:
        """Send one WebSocket message, choosing text or binary by ``data``'s type.

        Raises:
            TypeError: if ``data`` is neither ``str`` nor ``bytes``.
        """
        if isinstance(data, str):
            await self.send_text(data)
        elif isinstance(data, bytes):
            await self.send_bytes(data)
        else:
            raise TypeError("WebSocket send method expects str or bytes!")

    @abc.abstractmethod
    async def send_text(self, data: str) -> None:
        """Send one WebSocket text message."""
        raise NotImplementedError

    @abc.abstractmethod
    async def send_bytes(self, data: bytes) -> None:
        """Send one WebSocket binary message."""
        raise NotImplementedError


class Cookies(MutableMapping):
    """Cookie container backed by an :class:`http.cookiejar.CookieJar`.

    Item access (``cookies[name] = value``, ``cookies[name]``,
    ``del cookies[name]``) is keyed by cookie name.

    NOTE(review): iteration yields :class:`http.cookiejar.Cookie` objects, not
    names. Inherited ``Mapping`` helpers that pass iterated keys back into
    ``__getitem__`` (``items()``, ``values()``, ``popitem()``) therefore raise
    ``KeyError`` on a non-empty container. Iterate the cookies directly, or use
    :attr:`jar`, instead.
    """

    def __init__(self, cookies: CookieTypes = None) -> None:
        # A CookieJar argument is adopted as-is (shared, not copied). Any other
        # input gets a fresh jar, which is populated below.
        self.jar: CookieJar = cookies if isinstance(cookies, CookieJar) else CookieJar()

        if cookies is not None and not isinstance(cookies, CookieJar):
            if isinstance(cookies, dict):
                for key, value in cookies.items():
                    self.set(key, value)
            elif isinstance(cookies, list):
                for key, value in cookies:
                    self.set(key, value)
            elif isinstance(cookies, Cookies):
                # Copy each Cookie into this container's own jar.
                for cookie in cookies.jar:
                    self.jar.set_cookie(cookie)
            else:
                raise TypeError(f"Cookies must be dict or list, not {type(cookies)}")

    def set(self, name: str, value: str, domain: str = "", path: str = "/") -> None:
        """Add (or overwrite) a cookie with the given name, value, domain and path.

        The cookie is non-persistent (``discard=True``, no expiry) and not
        marked secure.
        """
        cookie = Cookie(
            version=0,
            name=name,
            value=value,
            port=None,
            port_specified=False,
            domain=domain,
            domain_specified=bool(domain),
            domain_initial_dot=domain.startswith("."),
            path=path,
            path_specified=bool(path),
            secure=False,
            expires=None,
            discard=True,
            comment=None,
            comment_url=None,
            rest={},
            rfc2109=False,
        )
        self.jar.set_cookie(cookie)

    def get(  # pyright: ignore[reportIncompatibleMethodOverride]
        self,
        name: str,
        default: Optional[str] = None,
        domain: Optional[str] = None,
        path: Optional[str] = None,
    ) -> Optional[str]:
        """Return the value of the cookie called ``name``, or ``default``.

        ``domain`` / ``path`` restrict the search to an exact match when given.

        Raises:
            ValueError: if another matching cookie is found after one whose
                value was not None (i.e. the lookup is ambiguous).
        """
        value: Optional[str] = None
        for cookie in self.jar:
            if (
                cookie.name == name
                and (domain is None or cookie.domain == domain)
                and (path is None or cookie.path == path)
            ):
                if value is not None:
                    message = f"Multiple cookies exist with name={name}"
                    raise ValueError(message)
                value = cookie.value

        return default if value is None else value

    def delete(
        self, name: str, domain: Optional[str] = None, path: Optional[str] = None
    ) -> None:
        """Delete cookies called ``name``, optionally limited to a domain and path.

        With both ``domain`` and ``path`` given, the exact cookie is removed via
        ``CookieJar.clear`` (which raises ``KeyError`` if it does not exist).
        Otherwise all matching cookies are removed, and no match is a no-op.
        """
        if domain is not None and path is not None:
            return self.jar.clear(domain, path, name)

        remove = [
            cookie
            for cookie in self.jar
            if cookie.name == name
            and (domain is None or cookie.domain == domain)
            and (path is None or cookie.path == path)
        ]

        for cookie in remove:
            self.jar.clear(cookie.domain, cookie.path, cookie.name)

    def clear(self, domain: Optional[str] = None, path: Optional[str] = None) -> None:
        """Remove cookies, optionally only those with an exact domain and/or path.

        With no arguments every cookie is removed. Unlike ``CookieJar.clear``,
        ``path`` may be given without ``domain``. A filter that matches nothing
        is a no-op rather than an error.
        """
        if domain is None and path is None:
            self.jar.clear()
            return

        # Collect first: the jar must not be mutated while it is being iterated.
        remove = [
            cookie
            for cookie in self.jar
            if (domain is None or cookie.domain == domain)
            and (path is None or cookie.path == path)
        ]
        for cookie in remove:
            self.jar.clear(cookie.domain, cookie.path, cookie.name)

    def update(  # pyright: ignore[reportIncompatibleMethodOverride]
        self, cookies: CookieTypes = None
    ) -> None:
        """Merge cookies from any accepted cookie container into this jar."""
        cookies = Cookies(cookies)
        for cookie in cookies.jar:
            self.jar.set_cookie(cookie)

    def as_header(self, request: Request) -> dict[str, str]:
        """Return the headers the jar adds for ``request``, as a plain dict.

        The request is wrapped in a urllib-compatible adapter so that
        ``CookieJar.add_cookie_header`` can choose the cookies. The headers it
        adds are captured and returned instead of being applied to ``request``.
        """
        urllib_request = self._CookieCompatRequest(request)
        self.jar.add_cookie_header(urllib_request)
        return urllib_request.added_headers

    def __setitem__(self, name: str, value: str) -> None:
        return self.set(name, value)

    def __getitem__(self, name: str) -> str:
        value = self.get(name)
        if value is None:
            raise KeyError(name)
        return value

    def __delitem__(self, name: str) -> None:
        # Delegates to delete(), which silently ignores missing names.
        return self.delete(name)

    def __len__(self) -> int:
        # Counts every cookie in the jar, including same-named ones.
        return len(self.jar)

    def __iter__(self) -> Iterator[Cookie]:
        # Yields Cookie objects, not names (see the class docstring).
        return iter(self.jar)

    def __repr__(self) -> str:
        cookies_repr = ", ".join(
            f"Cookie({cookie.name}={cookie.value} for {cookie.domain})"
            for cookie in self.jar
        )
        return f"{self.__class__.__name__}({cookies_repr})"

    class _CookieCompatRequest(urllib.request.Request):
        """urllib ``Request`` adapter that records headers added to it.

        Headers set through ``add_unredirected_header`` are also stored in
        :attr:`added_headers`, so callers can read them back.
        """

        def __init__(self, request: Request) -> None:
            super().__init__(
                url=str(request.url),
                headers=dict(request.headers),
                method=request.method,
            )
            self.request = request
            self.added_headers: dict[str, str] = {}

        def add_unredirected_header(  # pyright: ignore[reportIncompatibleMethodOverride]
            self, key: str, value: str
        ) -> None:
            super().add_unredirected_header(key, value)
            self.added_headers[key] = value


@dataclass
class HTTPServerSetup:
    """HTTP server route configuration."""

    # Route path. It should be relative, i.e. ``path.is_absolute()`` is False.
    path: URL
    # HTTP method this route responds to.
    method: str
    # Name of the route.
    name: str
    # Handler: takes the incoming Request, returns an awaitable Response.
    handle_func: Callable[[Request], Awaitable[Response]]


@dataclass
class WebSocketServerSetup:
    """WebSocket server route configuration."""

    # Route path. It should be relative, i.e. ``path.is_absolute()`` is False.
    path: URL
    # Name of the route.
    name: str
    # Handler: takes the route's WebSocket and returns an awaitable.
    handle_func: Callable[[WebSocket], Awaitable[Any]]