⚗️ change rule to use handler

This commit is contained in:
yanyongyu
2021-11-19 18:18:53 +08:00
parent ee619a33a9
commit 471d306e13
8 changed files with 182 additions and 148 deletions

View File

@ -18,6 +18,7 @@ from nonebot.log import logger
from .models import Param as Param
from .utils import get_typed_signature
from .models import Dependent as Dependent
from nonebot.exception import SkippedException
from .models import DependsWrapper as DependsWrapper
from nonebot.utils import (run_sync, is_gen_callable, run_sync_ctx_manager,
is_async_gen_callable, is_coroutine_callable)
@ -112,21 +113,20 @@ def get_dependent(*,
async def solve_dependencies(
*,
dependent: Dependent,
stack: Optional[AsyncExitStack] = None,
sub_dependents: Optional[List[Dependent]] = None,
dependency_overrides_provider: Optional[Any] = None,
dependency_cache: Optional[Dict[Callable[..., Any], Any]] = None,
**params: Any
) -> Tuple[Dict[str, Any], Dict[Callable[..., Any], Any], bool]:
*,
_dependent: Dependent,
_stack: Optional[AsyncExitStack] = None,
_sub_dependents: Optional[List[Dependent]] = None,
_dependency_overrides_provider: Optional[Any] = None,
_dependency_cache: Optional[Dict[Callable[..., Any], Any]] = None,
**params: Any) -> Tuple[Dict[str, Any], Dict[Callable[..., Any], Any]]:
values: Dict[str, Any] = {}
dependency_cache = dependency_cache or {}
dependency_cache = _dependency_cache or {}
# solve sub dependencies
sub_dependent: Dependent
for sub_dependent in chain(sub_dependents or tuple(),
dependent.dependencies):
for sub_dependent in chain(_sub_dependents or tuple(),
_dependent.dependencies):
sub_dependent.func = cast(Callable[..., Any], sub_dependent.func)
sub_dependent.cache_key = cast(Callable[..., Any],
sub_dependent.cache_key)
@ -134,10 +134,10 @@ async def solve_dependencies(
# dependency overrides
use_sub_dependant = sub_dependent
if (dependency_overrides_provider and
hasattr(dependency_overrides_provider, "dependency_overrides")):
if (_dependency_overrides_provider and hasattr(
_dependency_overrides_provider, "dependency_overrides")):
original_call = sub_dependent.func
func = getattr(dependency_overrides_provider,
func = getattr(_dependency_overrides_provider,
"dependency_overrides",
{}).get(original_call, original_call)
use_sub_dependant = get_dependent(
@ -148,13 +148,11 @@ async def solve_dependencies(
# solve sub dependency with current cache
solved_result = await solve_dependencies(
dependent=use_sub_dependant,
dependency_overrides_provider=dependency_overrides_provider,
_dependent=use_sub_dependant,
_dependency_overrides_provider=_dependency_overrides_provider,
dependency_cache=dependency_cache,
**params)
sub_values, sub_dependency_cache, ignored = solved_result
if ignored:
return values, dependency_cache, True
sub_values, sub_dependency_cache = solved_result
# update cache?
dependency_cache.update(sub_dependency_cache)
@ -163,13 +161,13 @@ async def solve_dependencies(
solved = dependency_cache[sub_dependent.cache_key]
elif is_gen_callable(func) or is_async_gen_callable(func):
assert isinstance(
stack, AsyncExitStack
_stack, AsyncExitStack
), "Generator dependency should be called in context"
if is_gen_callable(func):
cm = run_sync_ctx_manager(contextmanager(func)(**sub_values))
else:
cm = asynccontextmanager(func)(**sub_values)
solved = await stack.enter_async_context(cm)
solved = await _stack.enter_async_context(cm)
elif is_coroutine_callable(func):
solved = await func(**sub_values)
else:
@ -183,7 +181,7 @@ async def solve_dependencies(
dependency_cache[sub_dependent.cache_key] = solved
# usual dependency
for field in dependent.params:
for field in _dependent.params:
field_info = field.field_info
assert isinstance(field_info,
Param), "Params must be subclasses of Param"
@ -194,13 +192,13 @@ async def solve_dependencies(
if errs_:
logger.debug(
f"{field_info} "
f"type {type(value)} not match depends {dependent.func} "
f"type {type(value)} not match depends {_dependent.func} "
f"annotation {field._type_display()}, ignored")
return values, dependency_cache, True
raise SkippedException
else:
values[field.name] = value
return values, dependency_cache, False
return values, dependency_cache
def Depends(dependency: Optional[Callable[..., Any]] = None,