mirror of
https://github.com/nonebot/nonebot2.git
synced 2025-09-07 12:36:47 +00:00
⚗️ change rule to use handler
This commit is contained in:
@ -18,6 +18,7 @@ from nonebot.log import logger
|
||||
from .models import Param as Param
|
||||
from .utils import get_typed_signature
|
||||
from .models import Dependent as Dependent
|
||||
from nonebot.exception import SkippedException
|
||||
from .models import DependsWrapper as DependsWrapper
|
||||
from nonebot.utils import (run_sync, is_gen_callable, run_sync_ctx_manager,
|
||||
is_async_gen_callable, is_coroutine_callable)
|
||||
@ -112,21 +113,20 @@ def get_dependent(*,
|
||||
|
||||
|
||||
async def solve_dependencies(
|
||||
*,
|
||||
dependent: Dependent,
|
||||
stack: Optional[AsyncExitStack] = None,
|
||||
sub_dependents: Optional[List[Dependent]] = None,
|
||||
dependency_overrides_provider: Optional[Any] = None,
|
||||
dependency_cache: Optional[Dict[Callable[..., Any], Any]] = None,
|
||||
**params: Any
|
||||
) -> Tuple[Dict[str, Any], Dict[Callable[..., Any], Any], bool]:
|
||||
*,
|
||||
_dependent: Dependent,
|
||||
_stack: Optional[AsyncExitStack] = None,
|
||||
_sub_dependents: Optional[List[Dependent]] = None,
|
||||
_dependency_overrides_provider: Optional[Any] = None,
|
||||
_dependency_cache: Optional[Dict[Callable[..., Any], Any]] = None,
|
||||
**params: Any) -> Tuple[Dict[str, Any], Dict[Callable[..., Any], Any]]:
|
||||
values: Dict[str, Any] = {}
|
||||
dependency_cache = dependency_cache or {}
|
||||
dependency_cache = _dependency_cache or {}
|
||||
|
||||
# solve sub dependencies
|
||||
sub_dependent: Dependent
|
||||
for sub_dependent in chain(sub_dependents or tuple(),
|
||||
dependent.dependencies):
|
||||
for sub_dependent in chain(_sub_dependents or tuple(),
|
||||
_dependent.dependencies):
|
||||
sub_dependent.func = cast(Callable[..., Any], sub_dependent.func)
|
||||
sub_dependent.cache_key = cast(Callable[..., Any],
|
||||
sub_dependent.cache_key)
|
||||
@ -134,10 +134,10 @@ async def solve_dependencies(
|
||||
|
||||
# dependency overrides
|
||||
use_sub_dependant = sub_dependent
|
||||
if (dependency_overrides_provider and
|
||||
hasattr(dependency_overrides_provider, "dependency_overrides")):
|
||||
if (_dependency_overrides_provider and hasattr(
|
||||
_dependency_overrides_provider, "dependency_overrides")):
|
||||
original_call = sub_dependent.func
|
||||
func = getattr(dependency_overrides_provider,
|
||||
func = getattr(_dependency_overrides_provider,
|
||||
"dependency_overrides",
|
||||
{}).get(original_call, original_call)
|
||||
use_sub_dependant = get_dependent(
|
||||
@ -148,13 +148,11 @@ async def solve_dependencies(
|
||||
|
||||
# solve sub dependency with current cache
|
||||
solved_result = await solve_dependencies(
|
||||
dependent=use_sub_dependant,
|
||||
dependency_overrides_provider=dependency_overrides_provider,
|
||||
_dependent=use_sub_dependant,
|
||||
_dependency_overrides_provider=_dependency_overrides_provider,
|
||||
dependency_cache=dependency_cache,
|
||||
**params)
|
||||
sub_values, sub_dependency_cache, ignored = solved_result
|
||||
if ignored:
|
||||
return values, dependency_cache, True
|
||||
sub_values, sub_dependency_cache = solved_result
|
||||
# update cache?
|
||||
dependency_cache.update(sub_dependency_cache)
|
||||
|
||||
@ -163,13 +161,13 @@ async def solve_dependencies(
|
||||
solved = dependency_cache[sub_dependent.cache_key]
|
||||
elif is_gen_callable(func) or is_async_gen_callable(func):
|
||||
assert isinstance(
|
||||
stack, AsyncExitStack
|
||||
_stack, AsyncExitStack
|
||||
), "Generator dependency should be called in context"
|
||||
if is_gen_callable(func):
|
||||
cm = run_sync_ctx_manager(contextmanager(func)(**sub_values))
|
||||
else:
|
||||
cm = asynccontextmanager(func)(**sub_values)
|
||||
solved = await stack.enter_async_context(cm)
|
||||
solved = await _stack.enter_async_context(cm)
|
||||
elif is_coroutine_callable(func):
|
||||
solved = await func(**sub_values)
|
||||
else:
|
||||
@ -183,7 +181,7 @@ async def solve_dependencies(
|
||||
dependency_cache[sub_dependent.cache_key] = solved
|
||||
|
||||
# usual dependency
|
||||
for field in dependent.params:
|
||||
for field in _dependent.params:
|
||||
field_info = field.field_info
|
||||
assert isinstance(field_info,
|
||||
Param), "Params must be subclasses of Param"
|
||||
@ -194,13 +192,13 @@ async def solve_dependencies(
|
||||
if errs_:
|
||||
logger.debug(
|
||||
f"{field_info} "
|
||||
f"type {type(value)} not match depends {dependent.func} "
|
||||
f"type {type(value)} not match depends {_dependent.func} "
|
||||
f"annotation {field._type_display()}, ignored")
|
||||
return values, dependency_cache, True
|
||||
raise SkippedException
|
||||
else:
|
||||
values[field.name] = value
|
||||
|
||||
return values, dependency_cache, False
|
||||
return values, dependency_cache
|
||||
|
||||
|
||||
def Depends(dependency: Optional[Callable[..., Any]] = None,
|
||||
|
Reference in New Issue
Block a user