💥 Remove: 移除 Python 3.9 支持 (#3860)

This commit is contained in:
呵呵です
2026-02-18 00:11:36 +08:00
committed by GitHub
parent f719a6b41b
commit 63cde5da77
56 changed files with 603 additions and 1144 deletions

View File

@@ -1,6 +1,6 @@
import abc
from functools import partial
from typing import TYPE_CHECKING, Any, ClassVar, Optional, Protocol, Union
from typing import TYPE_CHECKING, Any, ClassVar, Protocol
import anyio
from exceptiongroup import BaseExceptionGroup, catch
@@ -77,7 +77,7 @@ class Bot(abc.ABC):
result: Any = None
skip_calling_api: bool = False
exception: Optional[Exception] = None
exception: Exception | None = None
if self._calling_api_hook:
logger.debug("Running CallingAPI hooks...")
@@ -180,7 +180,7 @@ class Bot(abc.ABC):
async def send(
self,
event: "Event",
message: Union[str, "Message", "MessageSegment"],
message: "str | Message | MessageSegment",
**kwargs: Any,
) -> Any:
"""调用机器人基础发送消息接口

View File

@@ -5,11 +5,9 @@ from dataclasses import asdict, dataclass, field
from typing import ( # noqa: UP035
Any,
Generic,
Optional,
SupportsIndex,
Type,
TypeVar,
Union,
overload,
)
from typing_extensions import Self
@@ -51,10 +49,10 @@ class MessageSegment(abc.ABC, Generic[TM]):
) -> bool:
return not self == other
def __add__(self, other: Union[str, Self, Iterable[Self]]) -> TM:
def __add__(self, other: str | Self | Iterable[Self]) -> TM:
return self.get_message_class()(self) + other
def __radd__(self, other: Union[str, Self, Iterable[Self]]) -> TM:
def __radd__(self, other: str | Self | Iterable[Self]) -> TM:
return self.get_message_class()(other) + self
@classmethod
@@ -87,7 +85,7 @@ class MessageSegment(abc.ABC, Generic[TM]):
def items(self):
return asdict(self).items()
def join(self, iterable: Iterable[Union[Self, TM]]) -> TM:
def join(self, iterable: Iterable[Self | TM]) -> TM:
return self.get_message_class()(self).join(iterable)
def copy(self) -> Self:
@@ -109,7 +107,7 @@ class Message(list[TMS], abc.ABC):
def __init__(
self,
message: Union[str, None, Iterable[TMS], TMS] = None,
message: str | None | Iterable[TMS] | TMS = None,
):
super().__init__()
if message is None:
@@ -124,7 +122,7 @@ class Message(list[TMS], abc.ABC):
self.extend(self._construct(message)) # pragma: no cover
@classmethod
def template(cls, format_string: Union[str, TM]) -> MessageTemplate[Self]:
def template(cls, format_string: str | TM) -> MessageTemplate[Self]:
"""创建消息模板。
用法和 `str.format` 大致相同,支持以 `Message` 对象作为消息模板并输出消息对象。
@@ -177,17 +175,17 @@ class Message(list[TMS], abc.ABC):
raise NotImplementedError
def __add__( # pyright: ignore[reportIncompatibleMethodOverride]
self, other: Union[str, TMS, Iterable[TMS]]
self, other: str | TMS | Iterable[TMS]
) -> Self:
result = self.copy()
result += other
return result
def __radd__(self, other: Union[str, TMS, Iterable[TMS]]) -> Self:
def __radd__(self, other: str | TMS | Iterable[TMS]) -> Self:
result = self.__class__(other)
return result + self
def __iadd__(self, other: Union[str, TMS, Iterable[TMS]]) -> Self:
def __iadd__(self, other: str | TMS | Iterable[TMS]) -> Self:
if isinstance(other, str):
self.extend(self._construct(other))
elif isinstance(other, MessageSegment):
@@ -255,14 +253,8 @@ class Message(list[TMS], abc.ABC):
def __getitem__( # pyright: ignore[reportIncompatibleMethodOverride]
self,
args: Union[
str,
tuple[str, int],
tuple[str, slice],
int,
slice,
],
) -> Union[TMS, Self]:
args: str | tuple[str, int] | tuple[str, slice] | int | slice,
) -> TMS | Self:
arg1, arg2 = args if isinstance(args, tuple) else (args, None)
if isinstance(arg1, int) and arg2 is None:
return super().__getitem__(arg1)
@@ -278,7 +270,7 @@ class Message(list[TMS], abc.ABC):
raise ValueError("Incorrect arguments to slice") # pragma: no cover
def __contains__( # pyright: ignore[reportIncompatibleMethodOverride]
self, value: Union[TMS, str]
self, value: TMS | str
) -> bool:
"""检查消息段是否存在
@@ -291,11 +283,11 @@ class Message(list[TMS], abc.ABC):
return next((seg for seg in self if seg.type == value), None) is not None
return super().__contains__(value)
def has(self, value: Union[TMS, str]) -> bool:
def has(self, value: TMS | str) -> bool:
"""{ref}``__contains__` <nonebot.adapters.Message.__contains__>` 相同"""
return value in self
def index(self, value: Union[TMS, str], *args: SupportsIndex) -> int:
def index(self, value: TMS | str, *args: SupportsIndex) -> int:
"""索引消息段
参数:
@@ -315,7 +307,7 @@ class Message(list[TMS], abc.ABC):
return super().index(first_segment, *args)
return super().index(value, *args)
def get(self, type_: str, count: Optional[int] = None) -> Self:
def get(self, type_: str, count: int | None = None) -> Self:
"""获取指定类型的消息段
参数:
@@ -339,7 +331,7 @@ class Message(list[TMS], abc.ABC):
filtered.append(seg)
return filtered
def count(self, value: Union[TMS, str]) -> int:
def count(self, value: TMS | str) -> int:
"""计算指定消息段的个数
参数:
@@ -350,7 +342,7 @@ class Message(list[TMS], abc.ABC):
"""
return len(self[value]) if isinstance(value, str) else super().count(value)
def only(self, value: Union[TMS, str]) -> bool:
def only(self, value: TMS | str) -> bool:
"""检查消息中是否仅包含指定消息段
参数:
@@ -364,7 +356,7 @@ class Message(list[TMS], abc.ABC):
return all(seg == value for seg in self)
def append( # pyright: ignore[reportIncompatibleMethodOverride]
self, obj: Union[str, TMS]
self, obj: str | TMS
) -> Self:
"""添加一个消息段到消息数组末尾。
@@ -380,7 +372,7 @@ class Message(list[TMS], abc.ABC):
return self
def extend( # pyright: ignore[reportIncompatibleMethodOverride]
self, obj: Union[Self, Iterable[TMS]]
self, obj: Self | Iterable[TMS]
) -> Self:
"""拼接一个消息数组或多个消息段到消息数组末尾。
@@ -391,7 +383,7 @@ class Message(list[TMS], abc.ABC):
self.append(segment)
return self
def join(self, iterable: Iterable[Union[TMS, Self]]) -> Self:
def join(self, iterable: Iterable[TMS | Self]) -> Self:
"""将多个消息连接并将自身作为分割
参数:

View File

@@ -1,19 +1,16 @@
from _string import formatter_field_name_split # type: ignore
from collections.abc import Mapping, Sequence
from collections.abc import Callable, Mapping, Sequence
import functools
from string import Formatter
from typing import (
TYPE_CHECKING,
Any,
Callable,
Generic,
Optional,
TypeAlias,
TypeVar,
Union,
cast,
overload,
)
from typing_extensions import TypeAlias
if TYPE_CHECKING:
from .message import Message, MessageSegment
@@ -50,15 +47,15 @@ class MessageTemplate(Formatter, Generic[TF]):
@overload
def __init__(
self: "MessageTemplate[TM]",
template: Union[str, TM],
template: str | TM,
factory: type[TM],
private_getattr: bool = False,
) -> None: ...
def __init__(
self,
template: Union[str, TM],
factory: Union[type[str], type[TM]] = str,
template: str | TM,
factory: type[str] | type[TM] = str,
private_getattr: bool = False,
) -> None:
self.template: TF = template # type: ignore
@@ -70,7 +67,7 @@ class MessageTemplate(Formatter, Generic[TF]):
return f"MessageTemplate({self.template!r}, factory={self.factory!r})"
def add_format_spec(
self, spec: FormatSpecFunc_T, name: Optional[str] = None
self, spec: FormatSpecFunc_T, name: str | None = None
) -> FormatSpecFunc_T:
name = name or spec.__name__
if name in self.format_specs:
@@ -126,7 +123,7 @@ class MessageTemplate(Formatter, Generic[TF]):
format_string: str,
args: Sequence[Any],
kwargs: Mapping[str, Any],
used_args: set[Union[int, str]],
used_args: set[int | str],
auto_arg_index: int = 0,
) -> tuple[TF, int]:
results: list[Any] = [self.factory()]
@@ -180,7 +177,7 @@ class MessageTemplate(Formatter, Generic[TF]):
def get_field(
self, field_name: str, args: Sequence[Any], kwargs: Mapping[str, Any]
) -> tuple[Any, Union[int, str]]:
) -> tuple[Any, int | str]:
first, rest = formatter_field_name_split(field_name)
obj = self.get_value(first, args, kwargs)
@@ -192,7 +189,7 @@ class MessageTemplate(Formatter, Generic[TF]):
return obj, first
def format_field(self, value: Any, format_spec: str) -> Any:
formatter: Optional[FormatSpecFunc] = self.format_specs.get(format_spec)
formatter: FormatSpecFunc | None = self.format_specs.get(format_spec)
if formatter is None and not issubclass(self.factory, str):
segment_class: type["MessageSegment"] = self.factory.get_segment_class()
method = getattr(segment_class, format_spec, None)

View File

@@ -1,7 +1,6 @@
from collections.abc import Awaitable, Iterable
from collections.abc import Awaitable, Callable, Iterable
from types import TracebackType
from typing import Any, Callable, Optional, Union, cast
from typing_extensions import TypeAlias
from typing import Any, TypeAlias, cast
import anyio
from anyio.abc import TaskGroup
@@ -11,12 +10,12 @@ from nonebot.utils import is_coroutine_callable, run_sync
SYNC_LIFESPAN_FUNC: TypeAlias = Callable[[], Any]
ASYNC_LIFESPAN_FUNC: TypeAlias = Callable[[], Awaitable[Any]]
LIFESPAN_FUNC: TypeAlias = Union[SYNC_LIFESPAN_FUNC, ASYNC_LIFESPAN_FUNC]
LIFESPAN_FUNC: TypeAlias = SYNC_LIFESPAN_FUNC | ASYNC_LIFESPAN_FUNC
class Lifespan:
def __init__(self) -> None:
self._task_group: Optional[TaskGroup] = None
self._task_group: TaskGroup | None = None
self._startup_funcs: list[LIFESPAN_FUNC] = []
self._ready_funcs: list[LIFESPAN_FUNC] = []
@@ -72,9 +71,9 @@ class Lifespan:
async def shutdown(
self,
*,
exc_type: Optional[type[BaseException]] = None,
exc_val: Optional[BaseException] = None,
exc_tb: Optional[TracebackType] = None,
exc_type: type[BaseException] | None = None,
exc_val: BaseException | None = None,
exc_tb: TracebackType | None = None,
) -> None:
if self._shutdown_funcs:
# reverse shutdown funcs to ensure stack order
@@ -93,8 +92,8 @@ class Lifespan:
async def __aexit__(
self,
exc_type: Optional[type[BaseException]],
exc_val: Optional[BaseException],
exc_tb: Optional[TracebackType],
exc_type: type[BaseException] | None,
exc_val: BaseException | None,
exc_tb: TracebackType | None,
) -> None:
await self.shutdown(exc_type=exc_type, exc_val=exc_val, exc_tb=exc_tb)

View File

@@ -2,8 +2,8 @@ import abc
from collections.abc import AsyncGenerator
from contextlib import AsyncExitStack, asynccontextmanager
from types import TracebackType
from typing import TYPE_CHECKING, Any, ClassVar, Optional, Union
from typing_extensions import Self, TypeAlias
from typing import TYPE_CHECKING, Any, ClassVar, TypeAlias
from typing_extensions import Self
from anyio import CancelScope, create_task_group
from anyio.abc import TaskGroup
@@ -245,9 +245,9 @@ class HTTPClientSession(abc.ABC):
params: QueryTypes = None,
headers: HeaderTypes = None,
cookies: CookieTypes = None,
version: Union[str, HTTPVersion] = HTTPVersion.H11,
version: str | HTTPVersion = HTTPVersion.H11,
timeout: TimeoutTypes = None,
proxy: Optional[str] = None,
proxy: str | None = None,
):
raise NotImplementedError
@@ -283,9 +283,9 @@ class HTTPClientSession(abc.ABC):
async def __aexit__(
self,
exc_type: Optional[type[BaseException]],
exc: Optional[BaseException],
tb: Optional[TracebackType],
exc_type: type[BaseException] | None,
exc: BaseException | None,
tb: TracebackType | None,
) -> None:
await self.close()
@@ -315,9 +315,9 @@ class HTTPClientMixin(ForwardMixin):
params: QueryTypes = None,
headers: HeaderTypes = None,
cookies: CookieTypes = None,
version: Union[str, HTTPVersion] = HTTPVersion.H11,
version: str | HTTPVersion = HTTPVersion.H11,
timeout: TimeoutTypes = None,
proxy: Optional[str] = None,
proxy: str | None = None,
) -> HTTPClientSession:
"""获取一个 HTTP 会话"""
raise NotImplementedError

View File

@@ -1,4 +1,4 @@
from typing import TYPE_CHECKING, TypeVar, Union, overload
from typing import TYPE_CHECKING, TypeVar, overload
from .abstract import Driver, Mixin
@@ -21,7 +21,7 @@ def combine_driver(
def combine_driver(
driver: type[D], *mixins: type[Mixin]
) -> Union[type[D], type["CombinedDriver"]]:
) -> type[D] | type["CombinedDriver"]:
"""将一个驱动器和多个混入类合并。"""
# check first
if not issubclass(driver, Driver):

View File

@@ -1,48 +1,51 @@
import abc
from collections.abc import Awaitable, Iterator, Mapping, MutableMapping
from collections.abc import Awaitable, Callable, Iterator, Mapping, MutableMapping
from dataclasses import dataclass
from enum import Enum
from http.cookiejar import Cookie, CookieJar
from typing import IO, Any, Callable, Optional, Union
from typing_extensions import TypeAlias
from typing import IO, Any, TypeAlias
import urllib.request
from multidict import CIMultiDict
from yarl import URL as URL
RawURL: TypeAlias = tuple[bytes, bytes, Optional[int], bytes]
SimpleQuery: TypeAlias = Union[str, int, float]
QueryVariable: TypeAlias = Union[SimpleQuery, list[SimpleQuery]]
QueryTypes: TypeAlias = Union[
None, str, Mapping[str, QueryVariable], list[tuple[str, SimpleQuery]]
]
@dataclass
class Timeout:
"""Request 超时配置。"""
HeaderTypes: TypeAlias = Union[
None,
CIMultiDict[str],
dict[str, str],
list[tuple[str, str]],
]
total: float | None = None
connect: float | None = None
read: float | None = None
CookieTypes: TypeAlias = Union[
None, "Cookies", CookieJar, dict[str, str], list[tuple[str, str]]
]
ContentTypes: TypeAlias = Union[str, bytes, None]
DataTypes: TypeAlias = Union[dict, None]
FileContent: TypeAlias = Union[IO[bytes], bytes]
FileType: TypeAlias = tuple[Optional[str], FileContent, Optional[str]]
FileTypes: TypeAlias = Union[
# file (or bytes)
FileContent,
# (filename, file (or bytes))
tuple[Optional[str], FileContent],
# (filename, file (or bytes), content_type)
FileType,
]
FilesTypes: TypeAlias = Union[dict[str, FileTypes], list[tuple[str, FileTypes]], None]
TimeoutTypes: TypeAlias = Union[float, "Timeout", None]
RawURL: TypeAlias = tuple[bytes, bytes, int | None, bytes]
SimpleQuery: TypeAlias = str | int | float
QueryVariable: TypeAlias = SimpleQuery | list[SimpleQuery]
QueryTypes: TypeAlias = (
None | str | Mapping[str, QueryVariable] | list[tuple[str, SimpleQuery]]
)
HeaderTypes: TypeAlias = (
None | CIMultiDict[str] | dict[str, str] | list[tuple[str, str]]
)
CookieTypes: TypeAlias = (
"None | Cookies | CookieJar | dict[str, str] | list[tuple[str, str]]"
)
ContentTypes: TypeAlias = str | bytes | None
DataTypes: TypeAlias = dict | None
FileContent: TypeAlias = IO[bytes] | bytes
FileType: TypeAlias = tuple[str | None, FileContent, str | None]
FileTypes: TypeAlias = (
FileContent # file (or bytes)
| tuple[str | None, FileContent] # (filename, file (or bytes))
| FileType # (filename, file (or bytes), content_type)
)
FilesTypes: TypeAlias = dict[str, FileTypes] | list[tuple[str, FileTypes]] | None
TimeoutTypes: TypeAlias = float | Timeout | None
class HTTPVersion(Enum):
@@ -51,20 +54,11 @@ class HTTPVersion(Enum):
H2 = "2"
@dataclass
class Timeout:
"""Request 超时配置。"""
total: Optional[float] = None
connect: Optional[float] = None
read: Optional[float] = None
class Request:
def __init__(
self,
method: Union[str, bytes],
url: Union["URL", str, RawURL],
method: str | bytes,
url: "URL | str | RawURL",
*,
params: QueryTypes = None,
headers: HeaderTypes = None,
@@ -73,9 +67,9 @@ class Request:
data: DataTypes = None,
json: Any = None,
files: FilesTypes = None,
version: Union[str, HTTPVersion] = HTTPVersion.H11,
version: str | HTTPVersion = HTTPVersion.H11,
timeout: TimeoutTypes = None,
proxy: Optional[str] = None,
proxy: str | None = None,
):
# method
self.method: str = (
@@ -88,7 +82,7 @@ class Request:
# timeout
self.timeout: TimeoutTypes = timeout
# proxy
self.proxy: Optional[str] = proxy
self.proxy: str | None = proxy
# url
if isinstance(url, tuple):
@@ -117,7 +111,7 @@ class Request:
self.content: ContentTypes = content
self.data: DataTypes = data
self.json: Any = json
self.files: Optional[list[tuple[str, FileType]]] = None
self.files: list[tuple[str, FileType]] | None = None
if files:
self.files = []
files_ = files.items() if isinstance(files, dict) else files
@@ -140,7 +134,7 @@ class Response:
*,
headers: HeaderTypes = None,
content: ContentTypes = None,
request: Optional[Request] = None,
request: Request | None = None,
):
# status code
self.status_code: int = status_code
@@ -153,7 +147,7 @@ class Response:
self.content: ContentTypes = content
# request
self.request: Optional[Request] = request
self.request: Request | None = request
def __repr__(self) -> str:
return f"{self.__class__.__name__}(status_code={self.status_code!r})"
@@ -183,7 +177,7 @@ class WebSocket(abc.ABC):
raise NotImplementedError
@abc.abstractmethod
async def receive(self) -> Union[str, bytes]:
async def receive(self) -> str | bytes:
"""接收一条 WebSocket text/bytes 信息"""
raise NotImplementedError
@@ -197,7 +191,7 @@ class WebSocket(abc.ABC):
"""接收一条 WebSocket binary 信息"""
raise NotImplementedError
async def send(self, data: Union[str, bytes]) -> None:
async def send(self, data: str | bytes) -> None:
"""发送一条 WebSocket text/bytes 信息"""
if isinstance(data, str):
await self.send_text(data)
@@ -258,11 +252,11 @@ class Cookies(MutableMapping):
def get( # pyright: ignore[reportIncompatibleMethodOverride]
self,
name: str,
default: Optional[str] = None,
domain: Optional[str] = None,
path: Optional[str] = None,
) -> Optional[str]:
value: Optional[str] = None
default: str | None = None,
domain: str | None = None,
path: str | None = None,
) -> str | None:
value: str | None = None
for cookie in self.jar:
if (
cookie.name == name
@@ -277,7 +271,7 @@ class Cookies(MutableMapping):
return default if value is None else value
def delete(
self, name: str, domain: Optional[str] = None, path: Optional[str] = None
self, name: str, domain: str | None = None, path: str | None = None
) -> None:
if domain is not None and path is not None:
return self.jar.clear(domain, path, name)
@@ -293,7 +287,7 @@ class Cookies(MutableMapping):
for cookie in remove:
self.jar.clear(cookie.domain, cookie.path, cookie.name)
def clear(self, domain: Optional[str] = None, path: Optional[str] = None) -> None:
def clear(self, domain: str | None = None, path: str | None = None) -> None:
self.jar.clear(domain, path)
def update( # pyright: ignore[reportIncompatibleMethodOverride]

View File

@@ -1,5 +1,5 @@
from collections.abc import ItemsView, Iterator, KeysView, MutableMapping, ValuesView
from typing import TYPE_CHECKING, Optional, TypeVar, Union, overload
from typing import TYPE_CHECKING, TypeVar, overload
from .provider import DEFAULT_PROVIDER_CLASS, MatcherProvider
@@ -52,7 +52,7 @@ class MatcherManager(MutableMapping[int, list[type["Matcher"]]]):
return self.provider.items()
@overload
def get(self, key: int) -> Optional[list[type["Matcher"]]]: ...
def get(self, key: int) -> list[type["Matcher"]] | None: ...
@overload
def get(
@@ -60,11 +60,11 @@ class MatcherManager(MutableMapping[int, list[type["Matcher"]]]):
) -> list[type["Matcher"]]: ...
@overload
def get(self, key: int, default: T) -> Union[list[type["Matcher"]], T]: ...
def get(self, key: int, default: T) -> list[type["Matcher"]] | T: ...
def get(
self, key: int, default: Optional[T] = None
) -> Optional[Union[list[type["Matcher"]], T]]:
self, key: int, default: T | None = None
) -> list[type["Matcher"]] | T | None:
return self.provider.get(key, default)
def pop( # pyright: ignore[reportIncompatibleMethodOverride]

View File

@@ -13,10 +13,8 @@ from typing import ( # noqa: UP035
Callable,
ClassVar,
NoReturn,
Optional,
Type,
TypeVar,
Union,
overload,
)
from typing_extensions import Self
@@ -87,15 +85,15 @@ current_handler: ContextVar[Dependent[Any]] = ContextVar("current_handler")
class MatcherSource:
"""Matcher 源代码上下文信息"""
plugin_id: Optional[str] = None
plugin_id: str | None = None
"""事件响应器所在插件标识符"""
module_name: Optional[str] = None
module_name: str | None = None
"""事件响应器所在插件模块的路径名"""
lineno: Optional[int] = None
lineno: int | None = None
"""事件响应器所在行号"""
@property
def plugin(self) -> Optional["Plugin"]:
def plugin(self) -> "Plugin | None":
"""事件响应器所在插件"""
from nonebot.plugin import get_plugin
@@ -103,17 +101,17 @@ class MatcherSource:
return get_plugin(self.plugin_id)
@property
def plugin_name(self) -> Optional[str]:
def plugin_name(self) -> str | None:
"""事件响应器所在插件名"""
return self.plugin and self.plugin.name
@property
def module(self) -> Optional[ModuleType]:
def module(self) -> ModuleType | None:
if self.module_name is not None:
return sys.modules.get(self.module_name)
@property
def file(self) -> Optional[Path]:
def file(self) -> Path | None:
if self.module is not None and (file := inspect.getsourcefile(self.module)):
return Path(file).absolute()
@@ -121,8 +119,8 @@ class MatcherSource:
class MatcherMeta(type):
if TYPE_CHECKING:
type: str
_source: Optional[MatcherSource]
module_name: Optional[str]
_source: MatcherSource | None
module_name: str | None
def __repr__(self) -> str:
return (
@@ -140,7 +138,7 @@ class MatcherMeta(type):
class Matcher(metaclass=MatcherMeta):
"""事件响应器类"""
_source: ClassVar[Optional[MatcherSource]] = None
_source: ClassVar[MatcherSource | None] = None
type: ClassVar[str] = ""
"""事件响应器类型"""
@@ -156,15 +154,15 @@ class Matcher(metaclass=MatcherMeta):
"""事件响应器是否阻止事件传播"""
temp: ClassVar[bool] = False
"""事件响应器是否为临时"""
expire_time: ClassVar[Optional[datetime]] = None
expire_time: ClassVar[datetime | None] = None
"""事件响应器过期时间点"""
_default_state: ClassVar[T_State] = {}
"""事件响应器默认状态"""
_default_type_updater: ClassVar[Optional[Dependent[str]]] = None
_default_type_updater: ClassVar[Dependent[str] | None] = None
"""事件响应器类型更新函数"""
_default_permission_updater: ClassVar[Optional[Dependent[Permission]]] = None
_default_permission_updater: ClassVar[Dependent[Permission] | None] = None
"""事件响应器权限更新函数"""
HANDLER_PARAM_TYPES: ClassVar[tuple[Type[Param], ...]] = ( # noqa: UP006
@@ -197,22 +195,22 @@ class Matcher(metaclass=MatcherMeta):
def new(
cls,
type_: str = "",
rule: Optional[Rule] = None,
permission: Optional[Permission] = None,
handlers: Optional[list[Union[T_Handler, Dependent[Any]]]] = None,
rule: Rule | None = None,
permission: Permission | None = None,
handlers: list[T_Handler | Dependent[Any]] | None = None,
temp: bool = False,
priority: int = 1,
block: bool = False,
*,
plugin: Optional["Plugin"] = None,
module: Optional[ModuleType] = None,
source: Optional[MatcherSource] = None,
expire_time: Optional[Union[datetime, timedelta]] = None,
default_state: Optional[T_State] = None,
default_type_updater: Optional[Union[T_TypeUpdater, Dependent[str]]] = None,
default_permission_updater: Optional[
Union[T_PermissionUpdater, Dependent[Permission]]
] = None,
plugin: "Plugin | None" = None,
module: ModuleType | None = None,
source: MatcherSource | None = None,
expire_time: datetime | timedelta | None = None,
default_state: T_State | None = None,
default_type_updater: T_TypeUpdater | Dependent[str] | None = None,
default_permission_updater: T_PermissionUpdater
| Dependent[Permission]
| None = None,
) -> Type[Self]: # noqa: UP006
"""
创建一个新的事件响应器,并存储至 `matchers <#matchers>`_
@@ -332,27 +330,27 @@ class Matcher(metaclass=MatcherMeta):
matchers[cls.priority].remove(cls)
@classproperty
def plugin(cls) -> Optional["Plugin"]:
def plugin(cls) -> "Plugin | None":
"""事件响应器所在插件"""
return cls._source and cls._source.plugin
@classproperty
def plugin_id(cls) -> Optional[str]:
def plugin_id(cls) -> str | None:
"""事件响应器所在插件标识符"""
return cls._source and cls._source.plugin_id
@classproperty
def plugin_name(cls) -> Optional[str]:
def plugin_name(cls) -> str | None:
"""事件响应器所在插件名"""
return cls._source and cls._source.plugin_name
@classproperty
def module(cls) -> Optional[ModuleType]:
def module(cls) -> ModuleType | None:
"""事件响应器所在插件模块"""
return cls._source and cls._source.module
@classproperty
def module_name(cls) -> Optional[str]:
def module_name(cls) -> str | None:
"""事件响应器所在插件模块路径"""
return cls._source and cls._source.module_name
@@ -361,8 +359,8 @@ class Matcher(metaclass=MatcherMeta):
cls,
bot: Bot,
event: Event,
stack: Optional[AsyncExitStack] = None,
dependency_cache: Optional[T_DependencyCache] = None,
stack: AsyncExitStack | None = None,
dependency_cache: T_DependencyCache | None = None,
) -> bool:
"""检查是否满足触发权限
@@ -386,8 +384,8 @@ class Matcher(metaclass=MatcherMeta):
bot: Bot,
event: Event,
state: T_State,
stack: Optional[AsyncExitStack] = None,
dependency_cache: Optional[T_DependencyCache] = None,
stack: AsyncExitStack | None = None,
dependency_cache: T_DependencyCache | None = None,
) -> bool:
"""检查是否满足匹配规则
@@ -432,7 +430,7 @@ class Matcher(metaclass=MatcherMeta):
@classmethod
def append_handler(
cls, handler: T_Handler, parameterless: Optional[Iterable[Any]] = None
cls, handler: T_Handler, parameterless: Iterable[Any] | None = None
) -> Dependent[Any]:
handler_ = Dependent[Any].parse(
call=handler,
@@ -444,7 +442,7 @@ class Matcher(metaclass=MatcherMeta):
@classmethod
def handle(
cls, parameterless: Optional[Iterable[Any]] = None
cls, parameterless: Iterable[Any] | None = None
) -> Callable[[T_Handler], T_Handler]:
"""装饰一个函数来向事件响应器直接添加一个处理函数
@@ -460,7 +458,7 @@ class Matcher(metaclass=MatcherMeta):
@classmethod
def receive(
cls, id: str = "", parameterless: Optional[Iterable[Any]] = None
cls, id: str = "", parameterless: Iterable[Any] | None = None
) -> Callable[[T_Handler], T_Handler]:
"""装饰一个函数来指示 NoneBot 在接收用户新的一条消息后继续运行该函数
@@ -503,8 +501,8 @@ class Matcher(metaclass=MatcherMeta):
def got(
cls,
key: str,
prompt: Optional[Union[str, Message, MessageSegment, MessageTemplate]] = None,
parameterless: Optional[Iterable[Any]] = None,
prompt: str | Message | MessageSegment | MessageTemplate | None = None,
parameterless: Iterable[Any] | None = None,
) -> Callable[[T_Handler], T_Handler]:
"""装饰一个函数来指示 NoneBot 获取一个参数 `key`
@@ -550,7 +548,7 @@ class Matcher(metaclass=MatcherMeta):
@classmethod
async def send(
cls,
message: Union[str, Message, MessageSegment, MessageTemplate],
message: str | Message | MessageSegment | MessageTemplate,
**kwargs: Any,
) -> Any:
"""发送一条消息给当前交互用户
@@ -572,7 +570,7 @@ class Matcher(metaclass=MatcherMeta):
@classmethod
async def finish(
cls,
message: Optional[Union[str, Message, MessageSegment, MessageTemplate]] = None,
message: str | Message | MessageSegment | MessageTemplate | None = None,
**kwargs,
) -> NoReturn:
"""发送一条消息给当前交互用户并结束当前事件响应器
@@ -589,7 +587,7 @@ class Matcher(metaclass=MatcherMeta):
@classmethod
async def pause(
cls,
prompt: Optional[Union[str, Message, MessageSegment, MessageTemplate]] = None,
prompt: str | Message | MessageSegment | MessageTemplate | None = None,
**kwargs,
) -> NoReturn:
"""发送一条消息给当前交互用户并暂停事件响应器,在接收用户新的一条消息后继续下一个处理函数
@@ -613,7 +611,7 @@ class Matcher(metaclass=MatcherMeta):
@classmethod
async def reject(
cls,
prompt: Optional[Union[str, Message, MessageSegment, MessageTemplate]] = None,
prompt: str | Message | MessageSegment | MessageTemplate | None = None,
**kwargs,
) -> NoReturn:
"""最近使用 `got` / `receive` 接收的消息不符合预期,
@@ -643,7 +641,7 @@ class Matcher(metaclass=MatcherMeta):
async def reject_arg(
cls,
key: str,
prompt: Optional[Union[str, Message, MessageSegment, MessageTemplate]] = None,
prompt: str | Message | MessageSegment | MessageTemplate | None = None,
**kwargs,
) -> NoReturn:
"""最近使用 `got` 接收的消息不符合预期,
@@ -668,7 +666,7 @@ class Matcher(metaclass=MatcherMeta):
async def reject_receive(
cls,
id: str = "",
prompt: Optional[Union[str, Message, MessageSegment, MessageTemplate]] = None,
prompt: str | Message | MessageSegment | MessageTemplate | None = None,
**kwargs,
) -> NoReturn:
"""最近使用 `receive` 接收的消息不符合预期,
@@ -698,14 +696,12 @@ class Matcher(metaclass=MatcherMeta):
raise SkippedException
@overload
def get_receive(self, id: str) -> Union[Event, None]: ...
def get_receive(self, id: str) -> Event | None: ...
@overload
def get_receive(self, id: str, default: T) -> Union[Event, T]: ...
def get_receive(self, id: str, default: T) -> Event | T: ...
def get_receive(
self, id: str, default: Optional[T] = None
) -> Optional[Union[Event, T]]:
def get_receive(self, id: str, default: T | None = None) -> Event | T | None:
"""获取一个 `receive` 事件
如果没有找到对应的事件,返回 `default` 值
@@ -718,14 +714,12 @@ class Matcher(metaclass=MatcherMeta):
self.state[LAST_RECEIVE_KEY] = event
@overload
def get_last_receive(self) -> Union[Event, None]: ...
def get_last_receive(self) -> Event | None: ...
@overload
def get_last_receive(self, default: T) -> Union[Event, T]: ...
def get_last_receive(self, default: T) -> Event | T: ...
def get_last_receive(
self, default: Optional[T] = None
) -> Optional[Union[Event, T]]:
def get_last_receive(self, default: T | None = None) -> Event | T | None:
"""获取最近一次 `receive` 事件
如果没有事件,返回 `default` 值
@@ -733,14 +727,12 @@ class Matcher(metaclass=MatcherMeta):
return self.state.get(LAST_RECEIVE_KEY, default)
@overload
def get_arg(self, key: str) -> Union[Message, None]: ...
def get_arg(self, key: str) -> Message | None: ...
@overload
def get_arg(self, key: str, default: T) -> Union[Message, T]: ...
def get_arg(self, key: str, default: T) -> Message | T: ...
def get_arg(
self, key: str, default: Optional[T] = None
) -> Optional[Union[Message, T]]:
def get_arg(self, key: str, default: T | None = None) -> Message | T | None:
"""获取一个 `got` 消息
如果没有找到对应的消息,返回 `default` 值
@@ -758,12 +750,12 @@ class Matcher(metaclass=MatcherMeta):
self.state[REJECT_TARGET] = target
@overload
def get_target(self) -> Union[str, None]: ...
def get_target(self) -> str | None: ...
@overload
def get_target(self, default: T) -> Union[str, T]: ...
def get_target(self, default: T) -> str | T: ...
def get_target(self, default: Optional[T] = None) -> Optional[Union[str, T]]:
def get_target(self, default: T | None = None) -> str | T | None:
return self.state.get(REJECT_TARGET, default)
def stop_propagation(self):
@@ -774,8 +766,8 @@ class Matcher(metaclass=MatcherMeta):
self,
bot: Bot,
event: Event,
stack: Optional[AsyncExitStack] = None,
dependency_cache: Optional[T_DependencyCache] = None,
stack: AsyncExitStack | None = None,
dependency_cache: T_DependencyCache | None = None,
) -> str:
updater = self.__class__._default_type_updater
return (
@@ -795,8 +787,8 @@ class Matcher(metaclass=MatcherMeta):
self,
bot: Bot,
event: Event,
stack: Optional[AsyncExitStack] = None,
dependency_cache: Optional[T_DependencyCache] = None,
stack: AsyncExitStack | None = None,
dependency_cache: T_DependencyCache | None = None,
) -> Permission:
if updater := self.__class__._default_permission_updater:
return await updater(
@@ -832,8 +824,8 @@ class Matcher(metaclass=MatcherMeta):
bot: Bot,
event: Event,
state: T_State,
stack: Optional[AsyncExitStack] = None,
dependency_cache: Optional[T_DependencyCache] = None,
stack: AsyncExitStack | None = None,
dependency_cache: T_DependencyCache | None = None,
):
logger.trace(
f"{self} run with incoming args: "
@@ -877,16 +869,14 @@ class Matcher(metaclass=MatcherMeta):
bot: Bot,
event: Event,
state: T_State,
stack: Optional[AsyncExitStack] = None,
dependency_cache: Optional[T_DependencyCache] = None,
stack: AsyncExitStack | None = None,
dependency_cache: T_DependencyCache | None = None,
):
exc: Optional[Union[FinishedException, RejectedException, PausedException]] = (
None
)
exc: FinishedException | RejectedException | PausedException | None = None
def _handle_special_exception(
exc_group: BaseExceptionGroup[
Union[FinishedException, RejectedException, PausedException]
FinishedException | RejectedException | PausedException
],
):
nonlocal exc

View File

@@ -1,3 +1,4 @@
from collections.abc import Callable
from contextlib import AsyncExitStack, asynccontextmanager, contextmanager
from enum import Enum
import inspect
@@ -5,13 +6,12 @@ from typing import (
TYPE_CHECKING,
Annotated,
Any,
Callable,
Literal,
Optional,
Union,
cast,
get_args,
get_origin,
)
from typing_extensions import Self, get_args, get_origin, override
from typing_extensions import Self, override
import anyio
from exceptiongroup import BaseExceptionGroup, catch
@@ -47,10 +47,10 @@ if TYPE_CHECKING:
class DependsInner:
def __init__(
self,
dependency: Optional[T_Handler] = None,
dependency: T_Handler | None = None,
*,
use_cache: bool = True,
validate: Union[bool, PydanticFieldInfo] = False,
validate: bool | PydanticFieldInfo = False,
) -> None:
self.dependency = dependency
self.use_cache = use_cache
@@ -64,10 +64,10 @@ class DependsInner:
def Depends(
dependency: Optional[T_Handler] = None,
dependency: T_Handler | None = None,
*,
use_cache: bool = True,
validate: Union[bool, PydanticFieldInfo] = False,
validate: bool | PydanticFieldInfo = False,
) -> Any:
"""子依赖装饰器
@@ -113,7 +113,7 @@ class DependencyCache:
def __init__(self):
self._state = CacheState.PENDING
self._result: Any = None
self._exception: Optional[BaseException] = None
self._exception: BaseException | None = None
self._waiter = anyio.Event()
def done(self) -> bool:
@@ -129,7 +129,7 @@ class DependencyCache:
raise self._exception
return self._result
def exception(self) -> Optional[BaseException]:
def exception(self) -> BaseException | None:
"""获取子依赖异常"""
if self._state != CacheState.FINISHED:
@@ -192,7 +192,7 @@ class DependParam(Param):
cls,
sub_dependent: Dependent[Any],
use_cache: bool,
validate: Union[bool, PydanticFieldInfo],
validate: bool | PydanticFieldInfo,
) -> Self:
return cls._inherit_construct(
validate if isinstance(validate, PydanticFieldInfo) else None,
@@ -205,7 +205,7 @@ class DependParam(Param):
@override
def _check_param(
cls, param: inspect.Parameter, allow_types: tuple[type[Param], ...]
) -> Optional[Self]:
) -> Self | None:
type_annotation, depends_inner = param.annotation, None
# extract type annotation and dependency from Annotated
if get_origin(param.annotation) is Annotated:
@@ -245,7 +245,7 @@ class DependParam(Param):
@override
def _check_parameterless(
cls, value: Any, allow_types: tuple[type[Param], ...]
) -> Optional["Param"]:
) -> "Param | None":
if isinstance(value, DependsInner):
assert value.dependency, "Dependency cannot be empty"
dependent = Dependent[Any].parse(
@@ -256,8 +256,8 @@ class DependParam(Param):
@override
async def _solve(
self,
stack: Optional[AsyncExitStack] = None,
dependency_cache: Optional[T_DependencyCache] = None,
stack: AsyncExitStack | None = None,
dependency_cache: T_DependencyCache | None = None,
**kwargs: Any,
) -> Any:
use_cache: bool = self.use_cache
@@ -267,7 +267,7 @@ class DependParam(Param):
call = cast(Callable[..., Any], sub_dependent.call)
# solve sub dependency with current cache
exc: Optional[BaseExceptionGroup[SkippedException]] = None
exc: BaseExceptionGroup[SkippedException] | None = None
def _handle_skipped(exc_group: BaseExceptionGroup[SkippedException]):
nonlocal exc
@@ -332,9 +332,7 @@ class BotParam(Param):
为保证兼容性,本注入还会解析名为 `bot` 且没有类型注解的参数。
"""
def __init__(
self, *args, checker: Optional[ModelField] = None, **kwargs: Any
) -> None:
def __init__(self, *args, checker: ModelField | None = None, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)
self.checker = checker
@@ -349,12 +347,12 @@ class BotParam(Param):
@override
def _check_param(
cls, param: inspect.Parameter, allow_types: tuple[type[Param], ...]
) -> Optional[Self]:
) -> Self | None:
from nonebot.adapters import Bot
# param type is Bot(s) or subclass(es) of Bot or None
if generic_check_issubclass(param.annotation, Bot):
checker: Optional[ModelField] = None
checker: ModelField | None = None
if param.annotation is not Bot:
checker = ModelField.construct(
name=param.name, annotation=param.annotation, field_info=FieldInfo()
@@ -386,9 +384,7 @@ class EventParam(Param):
为保证兼容性,本注入还会解析名为 `event` 且没有类型注解的参数。
"""
def __init__(
self, *args, checker: Optional[ModelField] = None, **kwargs: Any
) -> None:
def __init__(self, *args, checker: ModelField | None = None, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)
self.checker = checker
@@ -403,12 +399,12 @@ class EventParam(Param):
@override
def _check_param(
cls, param: inspect.Parameter, allow_types: tuple[type[Param], ...]
) -> Optional[Self]:
) -> Self | None:
from nonebot.adapters import Event
# param type is Event(s) or subclass(es) of Event or None
if generic_check_issubclass(param.annotation, Event):
checker: Optional[ModelField] = None
checker: ModelField | None = None
if param.annotation is not Event:
checker = ModelField.construct(
name=param.name, annotation=param.annotation, field_info=FieldInfo()
@@ -447,7 +443,7 @@ class StateParam(Param):
@override
def _check_param(
cls, param: inspect.Parameter, allow_types: tuple[type[Param], ...]
) -> Optional[Self]:
) -> Self | None:
# param type is T_State
if origin_is_annotated(
get_origin(param.annotation)
@@ -472,9 +468,7 @@ class MatcherParam(Param):
为保证兼容性,本注入还会解析名为 `matcher` 且没有类型注解的参数。
"""
def __init__(
self, *args, checker: Optional[ModelField] = None, **kwargs: Any
) -> None:
def __init__(self, *args, checker: ModelField | None = None, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)
self.checker = checker
@@ -489,12 +483,12 @@ class MatcherParam(Param):
@override
def _check_param(
cls, param: inspect.Parameter, allow_types: tuple[type[Param], ...]
) -> Optional[Self]:
) -> Self | None:
from nonebot.matcher import Matcher
# param type is Matcher(s) or subclass(es) of Matcher or None
if generic_check_issubclass(param.annotation, Matcher):
checker: Optional[ModelField] = None
checker: ModelField | None = None
if param.annotation is not Matcher:
checker = ModelField.construct(
name=param.name, annotation=param.annotation, field_info=FieldInfo()
@@ -520,31 +514,31 @@ class MatcherParam(Param):
class ArgInner:
def __init__(
self, key: Optional[str], type: Literal["message", "str", "plaintext", "prompt"]
self, key: str | None, type: Literal["message", "str", "plaintext", "prompt"]
) -> None:
self.key: Optional[str] = key
self.key: str | None = key
self.type: Literal["message", "str", "plaintext", "prompt"] = type
def __repr__(self) -> str:
return f"ArgInner(key={self.key!r}, type={self.type!r})"
def Arg(key: Optional[str] = None) -> Any:
def Arg(key: str | None = None) -> Any:
"""Arg 参数消息"""
return ArgInner(key, "message")
def ArgStr(key: Optional[str] = None) -> str:
def ArgStr(key: str | None = None) -> str:
"""Arg 参数消息文本"""
return ArgInner(key, "str") # type: ignore
def ArgPlainText(key: Optional[str] = None) -> str:
def ArgPlainText(key: str | None = None) -> str:
"""Arg 参数消息纯文本"""
return ArgInner(key, "plaintext") # type: ignore
def ArgPromptResult(key: Optional[str] = None) -> Any:
def ArgPromptResult(key: str | None = None) -> Any:
"""`arg` prompt 发送结果"""
return ArgInner(key, "prompt")
@@ -576,7 +570,7 @@ class ArgParam(Param):
@override
def _check_param(
cls, param: inspect.Parameter, allow_types: tuple[type[Param], ...]
) -> Optional[Self]:
) -> Self | None:
if isinstance(param.default, ArgInner):
return cls(key=param.default.key or param.name, type=param.default.type)
elif get_origin(param.annotation) is Annotated:
@@ -598,18 +592,18 @@ class ArgParam(Param):
else:
raise ValueError(f"Unknown Arg type: {self.type}")
def _solve_message(self, matcher: "Matcher") -> Optional["Message"]:
def _solve_message(self, matcher: "Matcher") -> "Message | None":
return matcher.get_arg(self.key)
def _solve_str(self, matcher: "Matcher") -> Optional[str]:
def _solve_str(self, matcher: "Matcher") -> str | None:
message = matcher.get_arg(self.key)
return str(message) if message is not None else None
def _solve_plaintext(self, matcher: "Matcher") -> Optional[str]:
def _solve_plaintext(self, matcher: "Matcher") -> str | None:
message = matcher.get_arg(self.key)
return message.extract_plain_text() if message is not None else None
def _solve_prompt(self, matcher: "Matcher") -> Optional[Any]:
def _solve_prompt(self, matcher: "Matcher") -> Any | None:
return matcher.state.get(
REJECT_PROMPT_RESULT_KEY.format(key=ARG_KEY.format(key=self.key))
)
@@ -630,7 +624,7 @@ class ExceptionParam(Param):
@override
def _check_param(
cls, param: inspect.Parameter, allow_types: tuple[type[Param], ...]
) -> Optional[Self]:
) -> Self | None:
# param type is Exception(s) or subclass(es) of Exception or None
if generic_check_issubclass(param.annotation, Exception):
return cls()
@@ -639,7 +633,7 @@ class ExceptionParam(Param):
return cls()
@override
async def _solve(self, exception: Optional[Exception] = None, **kwargs: Any) -> Any:
async def _solve(self, exception: Exception | None = None, **kwargs: Any) -> Any:
return exception
@@ -658,7 +652,7 @@ class DefaultParam(Param):
@override
def _check_param(
cls, param: inspect.Parameter, allow_types: tuple[type[Param], ...]
) -> Optional[Self]:
) -> Self | None:
if param.default != param.empty:
return cls(default=param.default)

View File

@@ -1,5 +1,5 @@
from contextlib import AsyncExitStack
from typing import ClassVar, NoReturn, Optional, Union
from typing import ClassVar, NoReturn
from typing_extensions import Self
import anyio
@@ -38,7 +38,7 @@ class Permission:
DefaultParam,
]
def __init__(self, *checkers: Union[T_PermissionChecker, Dependent[bool]]) -> None:
def __init__(self, *checkers: T_PermissionChecker | Dependent[bool]) -> None:
self.checkers: set[Dependent[bool]] = {
(
checker
@@ -58,8 +58,8 @@ class Permission:
self,
bot: Bot,
event: Event,
stack: Optional[AsyncExitStack] = None,
dependency_cache: Optional[T_DependencyCache] = None,
stack: AsyncExitStack | None = None,
dependency_cache: T_DependencyCache | None = None,
) -> bool:
"""检查是否满足某个权限。
@@ -95,9 +95,7 @@ class Permission:
def __and__(self, other: object) -> NoReturn:
raise RuntimeError("And operation between Permissions is not allowed.")
def __or__(
self, other: Optional[Union["Permission", T_PermissionChecker]]
) -> "Permission":
def __or__(self, other: "Permission | T_PermissionChecker | None") -> "Permission":
if other is None:
return self
elif isinstance(other, Permission):
@@ -105,9 +103,7 @@ class Permission:
else:
return Permission(*self.checkers, other)
def __ror__(
self, other: Optional[Union["Permission", T_PermissionChecker]]
) -> "Permission":
def __ror__(self, other: "Permission | T_PermissionChecker | None") -> "Permission":
if other is None:
return self
elif isinstance(other, Permission):
@@ -126,9 +122,7 @@ class User:
__slots__ = ("perm", "users")
def __init__(
self, users: tuple[str, ...], perm: Optional[Permission] = None
) -> None:
def __init__(self, users: tuple[str, ...], perm: Permission | None = None) -> None:
self.users = users
self.perm = perm
@@ -149,7 +143,7 @@ class User:
)
@classmethod
def _clean_permission(cls, perm: Permission) -> Optional[Permission]:
def _clean_permission(cls, perm: Permission) -> Permission | None:
if len(perm.checkers) == 1 and isinstance(
user_perm := next(iter(perm.checkers)).call, cls
):
@@ -157,7 +151,7 @@ class User:
return perm
@classmethod
def from_event(cls, event: Event, perm: Optional[Permission] = None) -> Self:
def from_event(cls, event: Event, perm: Permission | None = None) -> Self:
"""从事件中获取会话 ID。
如果 `perm` 中仅有 `User` 类型的权限检查函数,则会去除原有的会话 ID 限制。
@@ -169,7 +163,7 @@ class User:
return cls((event.get_session_id(),), perm=perm and cls._clean_permission(perm))
@classmethod
def from_permission(cls, *users: str, perm: Optional[Permission] = None) -> Self:
def from_permission(cls, *users: str, perm: Permission | None = None) -> Self:
"""指定会话与权限。
如果 `perm` 中仅有 `User` 类型的权限检查函数,则会去除原有的会话 ID 限制。
@@ -181,7 +175,7 @@ class User:
return cls(users, perm=perm and cls._clean_permission(perm))
def USER(*users: str, perm: Optional[Permission] = None):
def USER(*users: str, perm: Permission | None = None):
"""匹配当前事件属于指定会话。
如果 `perm` 中仅有 `User` 类型的权限检查函数,则会去除原有检查函数的会话 ID 限制。

View File

@@ -1,5 +1,5 @@
from contextlib import AsyncExitStack
from typing import ClassVar, NoReturn, Optional, Union
from typing import ClassVar, NoReturn
import anyio
from exceptiongroup import BaseExceptionGroup, catch
@@ -38,7 +38,7 @@ class Rule:
DefaultParam,
]
def __init__(self, *checkers: Union[T_RuleChecker, Dependent[bool]]) -> None:
def __init__(self, *checkers: T_RuleChecker | Dependent[bool]) -> None:
self.checkers: set[Dependent[bool]] = {
(
checker
@@ -59,8 +59,8 @@ class Rule:
bot: Bot,
event: Event,
state: T_State,
stack: Optional[AsyncExitStack] = None,
dependency_cache: Optional[T_DependencyCache] = None,
stack: AsyncExitStack | None = None,
dependency_cache: T_DependencyCache | None = None,
) -> bool:
"""检查是否符合所有规则
@@ -101,7 +101,7 @@ class Rule:
return result
def __and__(self, other: Optional[Union["Rule", T_RuleChecker]]) -> "Rule":
def __and__(self, other: "Rule | T_RuleChecker | None") -> "Rule":
if other is None:
return self
elif isinstance(other, Rule):
@@ -109,7 +109,7 @@ class Rule:
else:
return Rule(*self.checkers, other)
def __rand__(self, other: Optional[Union["Rule", T_RuleChecker]]) -> "Rule":
def __rand__(self, other: "Rule | T_RuleChecker | None") -> "Rule":
if other is None:
return self
elif isinstance(other, Rule):