🚧 add generator dependency support

This commit is contained in:
yanyongyu
2021-11-15 01:28:47 +08:00
parent 0a1ae75b70
commit cafe5c9af0
5 changed files with 75 additions and 17 deletions

View File

@@ -1,15 +1,17 @@
import inspect
from itertools import chain
from typing import Any, Dict, List, Tuple, Callable, Optional, cast
from contextlib import AsyncExitStack, contextmanager, asynccontextmanager
from .models import Dependent
from nonebot.log import logger
from nonebot.typing import T_State
from nonebot.adapters import Bot, Event
from .models import Depends as DependsClass
from nonebot.utils import run_sync, is_coroutine_callable
from .utils import (generic_get_types, get_typed_signature,
generic_check_issubclass)
from nonebot.utils import (run_sync, is_gen_callable, run_sync_ctx_manager,
is_async_gen_callable, is_coroutine_callable)
def get_param_sub_dependent(*, param: inspect.Parameter) -> Dependent:
@@ -95,11 +97,12 @@ async def solve_dependencies(
bot: Bot,
event: Event,
state: T_State,
matcher: "Matcher",
matcher: Optional["Matcher"],
stack: Optional[AsyncExitStack] = None,
sub_dependents: Optional[List[Dependent]] = None,
dependency_overrides_provider: Optional[Any] = None,
dependency_cache: Optional[Dict[Tuple[Callable[..., Any]], Any]] = None,
) -> Tuple[Dict[str, Any], Dict[Tuple[Callable[..., Any]], Any], bool]:
dependency_cache: Optional[Dict[Callable[..., Any], Any]] = None,
) -> Tuple[Dict[str, Any], Dict[Callable[..., Any], Any], bool]:
values: Dict[str, Any] = {}
dependency_cache = dependency_cache or {}
@@ -108,7 +111,7 @@ async def solve_dependencies(
for sub_dependent in chain(sub_dependents or tuple(),
dependent.dependencies):
sub_dependent.func = cast(Callable[..., Any], sub_dependent.func)
sub_dependent.cache_key = cast(Tuple[Callable[..., Any]],
sub_dependent.cache_key = cast(Callable[..., Any],
sub_dependent.cache_key)
func = sub_dependent.func
@@ -158,6 +161,15 @@ async def solve_dependencies(
# run dependency function
if sub_dependent.use_cache and sub_dependent.cache_key in dependency_cache:
solved = dependency_cache[sub_dependent.cache_key]
elif is_gen_callable(func) or is_async_gen_callable(func):
assert isinstance(
stack, AsyncExitStack
), "Generator dependency should be called in context"
if is_gen_callable(func):
cm = run_sync_ctx_manager(contextmanager(func)(**sub_values))
else:
cm = asynccontextmanager(func)(**sub_values)
solved = await stack.enter_async_context(cm)
elif is_coroutine_callable(func):
solved = await func(**sub_values)
else:

View File

@@ -6,6 +6,7 @@
"""
import asyncio
from contextlib import AsyncExitStack
from typing import TYPE_CHECKING, Any, Dict, List, Tuple, Callable, Optional
from nonebot.log import logger
@@ -37,15 +38,15 @@ class Handler:
self.name = get_name(func) if name is None else name
self.dependencies = dependencies or []
self.sub_dependents: Dict[Tuple[Callable[..., Any]], Dependent] = {}
self.sub_dependents: Dict[Callable[..., Any], Dependent] = {}
if dependencies:
for depends in dependencies:
if not depends.dependency:
raise ValueError(f"{depends} has no dependency")
if (depends.dependency,) in self.sub_dependents:
if depends.dependency in self.sub_dependents:
raise ValueError(f"{depends} is already in dependencies")
sub_dependant = get_parameterless_sub_dependant(depends=depends)
self.sub_dependents[(depends.dependency,)] = sub_dependant
self.sub_dependents[depends.dependency] = sub_dependant
self.dependency_overrides_provider = dependency_overrides_provider
self.dependent = get_dependent(func=func)
@@ -60,19 +61,29 @@ class Handler:
def __str__(self) -> str:
return repr(self)
async def __call__(self, matcher: "Matcher", bot: "Bot", event: "Event",
state: T_State):
async def __call__(
self,
matcher: "Matcher",
bot: "Bot",
event: "Event",
state: T_State,
*,
stack: Optional[AsyncExitStack] = None,
dependency_cache: Optional[Dict[Callable[..., Any],
Any]] = None) -> Any:
values, _, ignored = await solve_dependencies(
dependent=self.dependent,
bot=bot,
event=event,
state=state,
matcher=matcher,
stack=stack,
sub_dependents=[
self.sub_dependents[(dependency.dependency,)] # type: ignore
self.sub_dependents[dependency.dependency] # type: ignore
for dependency in self.dependencies
],
dependency_overrides_provider=self.dependency_overrides_provider)
dependency_overrides_provider=self.dependency_overrides_provider,
dependency_cache=dependency_cache)
if ignored:
return
@@ -101,7 +112,7 @@ class Handler:
if (dependency.dependency,) in self.sub_dependents:
raise ValueError(f"{dependency} is already in dependencies")
sub_dependant = get_parameterless_sub_dependant(depends=dependency)
self.sub_dependents[(dependency.dependency,)] = sub_dependant
self.sub_dependents[dependency.dependency] = sub_dependant
def prepend_dependency(self, dependency: Depends):
self.cache_dependent(dependency)
@@ -114,7 +125,7 @@ class Handler:
def remove_dependency(self, dependency: Depends):
if not dependency.dependency:
raise ValueError(f"{dependency} has no dependency")
if (dependency.dependency,) in self.sub_dependents:
del self.sub_dependents[(dependency.dependency,)]
if dependency.dependency in self.sub_dependents:
del self.sub_dependents[dependency.dependency]
if dependency in self.dependencies:
self.dependencies.remove(dependency)

View File

@@ -45,4 +45,4 @@ class Dependent:
self.matcher_param_name = matcher_param_name
self.dependencies = dependencies or []
self.use_cache = use_cache
self.cache_key = (self.func,)
self.cache_key = self.func