Feature: 迁移至结构化并发框架 AnyIO (#3053)

This commit is contained in:
Ju4tCode
2024-10-26 15:36:01 +08:00
committed by GitHub
parent bd9befbb55
commit ff21ceb946
39 changed files with 5422 additions and 4080 deletions

View File

@ -9,21 +9,22 @@ FrontMatter:
import re
import json
import asyncio
import inspect
import importlib
import contextlib
import dataclasses
from pathlib import Path
from collections import deque
from contextvars import copy_context
from functools import wraps, partial
from contextlib import AbstractContextManager, asynccontextmanager
from typing_extensions import ParamSpec, get_args, override, get_origin
from collections.abc import Mapping, Sequence, Coroutine, AsyncGenerator
from typing import Any, Union, Generic, TypeVar, Callable, Optional, overload
from collections.abc import Mapping, Sequence, Coroutine, Generator, AsyncGenerator
import anyio
import anyio.to_thread
from pydantic import BaseModel
from exceptiongroup import BaseExceptionGroup, catch
from nonebot.log import logger
from nonebot.typing import (
@ -39,6 +40,7 @@ R = TypeVar("R")
T = TypeVar("T")
K = TypeVar("K")
V = TypeVar("V")
E = TypeVar("E", bound=BaseException)
def escape_tag(s: str) -> str:
@ -178,11 +180,9 @@ def run_sync(call: Callable[P, R]) -> Callable[P, Coroutine[None, None, R]]:
@wraps(call)
async def _wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
loop = asyncio.get_running_loop()
pfunc = partial(call, *args, **kwargs)
context = copy_context()
result = await loop.run_in_executor(None, partial(context.run, pfunc))
return result
return await anyio.to_thread.run_sync(
partial(call, *args, **kwargs), abandon_on_cancel=True
)
return _wrapper
@ -234,10 +234,34 @@ async def run_coro_with_catch(
协程的返回值或发生异常时的指定值
"""
try:
with catch({exc: lambda exc_group: None}):
return await coro
except exc:
return return_on_err
return return_on_err
async def run_coro_with_shield(coro: Coroutine[Any, Any, T]) -> T:
"""运行协程并在取消时屏蔽取消异常。
参数:
coro: 要运行的协程
返回:
协程的返回值
"""
with anyio.CancelScope(shield=True):
return await coro
def flatten_exception_group(
exc_group: BaseExceptionGroup[E],
) -> Generator[E, None, None]:
for exc in exc_group.exceptions:
if isinstance(exc, BaseExceptionGroup):
yield from flatten_exception_group(exc)
else:
yield exc
def get_name(obj: Any) -> str: