"""本模块包含了 NoneBot 的一些工具函数 FrontMatter: mdx: format: md sidebar_position: 8 description: nonebot.utils 模块 """ from collections import deque from collections.abc import ( AsyncGenerator, Callable, Coroutine, Generator, Mapping, Sequence, ) import contextlib from contextlib import AbstractContextManager, asynccontextmanager import dataclasses from functools import partial, wraps import importlib import inspect import json from pathlib import Path import re from typing import ( Any, Generic, TypeVar, get_args, get_origin, overload, ) from typing_extensions import ParamSpec, override import anyio import anyio.to_thread from exceptiongroup import BaseExceptionGroup, catch from pydantic import BaseModel from nonebot.log import logger from nonebot.typing import ( all_literal_values, is_none_type, origin_is_literal, origin_is_union, type_has_args, ) P = ParamSpec("P") R = TypeVar("R") T = TypeVar("T") K = TypeVar("K") V = TypeVar("V") E = TypeVar("E", bound=BaseException) def escape_tag(s: str) -> str: """用于记录带颜色日志时转义 `` 类型特殊标签 参考: [loguru color 标签](https://loguru.readthedocs.io/en/stable/api/logger.html#color) 参数: s: 需要转义的字符串 """ return re.sub(r"\s]*)>", r"\\\g<0>", s) def deep_update( mapping: dict[K, Any], *updating_mappings: dict[K, Any] ) -> dict[K, Any]: """深度更新合并字典""" updated_mapping = mapping.copy() for updating_mapping in updating_mappings: for k, v in updating_mapping.items(): if ( k in updated_mapping and isinstance(updated_mapping[k], dict) and isinstance(v, dict) ): updated_mapping[k] = deep_update(updated_mapping[k], v) else: updated_mapping[k] = v return updated_mapping def lenient_issubclass( cls: Any, class_or_tuple: type[Any] | tuple[type[Any], ...] ) -> bool: """检查 cls 是否是 class_or_tuple 中的一个类型子类并忽略类型错误。""" try: return isinstance(cls, type) and issubclass(cls, class_or_tuple) except TypeError: return False def generic_check_issubclass( cls: Any, class_or_tuple: type[Any] | tuple[type[Any], ...] ) -> bool: """检查 cls 是否是 class_or_tuple 中的一个类型子类。 特别的: - 如果 cls 是 `typing.TypeVar` 类型, 则会检查其 `__bound__` 或 `__constraints__` 是否是 class_or_tuple 中一个类型的子类或 None。 - 如果 cls 是 `typing.Union` 或 `types.UnionType` 类型, 则会检查其中的所有类型是否是 class_or_tuple 中一个类型的子类或 None。 - 如果 cls 是 `typing.Literal` 类型, 则会检查其中的所有值是否是 class_or_tuple 中一个类型的实例。 - 如果 cls 是 `typing.List`、`typing.Dict` 等泛型类型, 则会检查其原始类型是否是 class_or_tuple 中一个类型的子类。 """ # if the target is a TypeVar, we check it first if isinstance(cls, TypeVar): if cls.__constraints__: return all( is_none_type(type_) or generic_check_issubclass(type_, class_or_tuple) for type_ in cls.__constraints__ ) elif cls.__bound__: return generic_check_issubclass(cls.__bound__, class_or_tuple) return False # elif the target is not a generic type, we check it directly elif not type_has_args(cls): with contextlib.suppress(TypeError): return issubclass(cls, class_or_tuple) origin = get_origin(cls) if origin_is_union(origin): return all( is_none_type(type_) or generic_check_issubclass(type_, class_or_tuple) for type_ in get_args(cls) ) elif origin_is_literal(origin): return all( is_none_type(value) or isinstance(value, class_or_tuple) for value in all_literal_values(cls) ) # ensure generic List, Dict can be checked elif origin: # avoid class check error (typing.Final, typing.ClassVar, etc...) try: return issubclass(origin, class_or_tuple) except TypeError: return False return False def type_is_complex(type_: type[Any]) -> bool: """检查 type_ 是否是复杂类型""" origin = get_origin(type_) return _type_is_complex_inner(type_) or _type_is_complex_inner(origin) def _type_is_complex_inner(type_: type[Any] | None) -> bool: if lenient_issubclass(type_, (str, bytes)): return False return lenient_issubclass( type_, (BaseModel, Mapping, Sequence, tuple, set, frozenset, deque) ) or dataclasses.is_dataclass(type_) def is_coroutine_callable(call: Callable[..., Any]) -> bool: """检查 call 是否是一个 callable 协程函数""" if inspect.isroutine(call): return inspect.iscoroutinefunction(call) if inspect.isclass(call): return False func_ = getattr(call, "__call__", None) return inspect.iscoroutinefunction(func_) def is_gen_callable(call: Callable[..., Any]) -> bool: """检查 call 是否是一个生成器函数""" if inspect.isgeneratorfunction(call): return True func_ = getattr(call, "__call__", None) return inspect.isgeneratorfunction(func_) def is_async_gen_callable(call: Callable[..., Any]) -> bool: """检查 call 是否是一个异步生成器函数""" if inspect.isasyncgenfunction(call): return True func_ = getattr(call, "__call__", None) return inspect.isasyncgenfunction(func_) def run_sync(call: Callable[P, R]) -> Callable[P, Coroutine[None, None, R]]: """一个用于包装 sync function 为 async function 的装饰器 参数: call: 被装饰的同步函数 """ @wraps(call) async def _wrapper(*args: P.args, **kwargs: P.kwargs) -> R: return await anyio.to_thread.run_sync( partial(call, *args, **kwargs), abandon_on_cancel=True ) return _wrapper @asynccontextmanager async def run_sync_ctx_manager( cm: AbstractContextManager[T], ) -> AsyncGenerator[T, None]: """一个用于包装 sync context manager 为 async context manager 的执行函数""" try: yield await run_sync(cm.__enter__)() except Exception as e: ok = await run_sync(cm.__exit__)(type(e), e, None) if not ok: raise e else: await run_sync(cm.__exit__)(None, None, None) @overload async def run_coro_with_catch( coro: Coroutine[Any, Any, T], exc: tuple[type[Exception], ...], return_on_err: None = None, ) -> T | None: ... @overload async def run_coro_with_catch( coro: Coroutine[Any, Any, T], exc: tuple[type[Exception], ...], return_on_err: R, ) -> T | R: ... async def run_coro_with_catch( coro: Coroutine[Any, Any, T], exc: tuple[type[Exception], ...], return_on_err: R | None = None, ) -> T | R | None: """运行协程并当遇到指定异常时返回指定值。 参数: coro: 要运行的协程 exc: 要捕获的异常 return_on_err: 当发生异常时返回的值 返回: 协程的返回值或发生异常时的指定值 """ with catch({exc: lambda exc_group: None}): return await coro return return_on_err async def run_coro_with_shield(coro: Coroutine[Any, Any, T]) -> T: """运行协程并在取消时屏蔽取消异常。 参数: coro: 要运行的协程 返回: 协程的返回值 """ with anyio.CancelScope(shield=True): return await coro raise RuntimeError("This should not happen") def flatten_exception_group( exc_group: BaseExceptionGroup[E], ) -> Generator[E, None, None]: for exc in exc_group.exceptions: if isinstance(exc, BaseExceptionGroup): yield from flatten_exception_group(exc) else: yield exc def get_name(obj: Any) -> str: """获取对象的名称""" if inspect.isfunction(obj) or inspect.isclass(obj): return obj.__name__ return obj.__class__.__name__ def path_to_module_name(path: Path) -> str: """转换路径为模块名""" rel_path = path.resolve().relative_to(Path.cwd().resolve()) if rel_path.stem == "__init__": return ".".join(rel_path.parts[:-1]) else: return ".".join((*rel_path.parts[:-1], rel_path.stem)) def resolve_dot_notation( obj_str: str, default_attr: str, default_prefix: str | None = None ) -> Any: """解析并导入点分表示法的对象""" modulename, _, cls = obj_str.partition(":") if default_prefix is not None and modulename.startswith("~"): modulename = default_prefix + modulename[1:] module = importlib.import_module(modulename) if not cls: return getattr(module, default_attr) instance = module for attr_str in cls.split("."): instance = getattr(instance, attr_str) return instance class classproperty(Generic[T]): """类属性装饰器""" def __init__(self, func: Callable[[Any], T]) -> None: self.func = func def __get__(self, instance: Any, owner: type[Any] | None = None) -> T: return self.func(type(instance) if owner is None else owner) class DataclassEncoder(json.JSONEncoder): """可以序列化 {ref}`nonebot.adapters.Message`(List[Dataclass]) 的 `JSONEncoder`""" @override def default(self, o): if dataclasses.is_dataclass(o): return {f.name: getattr(o, f.name) for f in dataclasses.fields(o)} return super().default(o) def logger_wrapper(logger_name: str): """用于打印 adapter 的日志。 参数: logger_name: adapter 的名称 返回: 日志记录函数 日志记录函数的参数: - level: 日志等级 - message: 日志信息 - exception: 异常信息 """ def log(level: str, message: str, exception: Exception | None = None): logger.opt(colors=True, exception=exception).log( level, f"{escape_tag(logger_name)} | {message}" ) return log