improve dependency cache

This commit is contained in:
yanyongyu
2021-12-16 23:22:25 +08:00
parent fe69735ca0
commit 3d762fcbab
15 changed files with 162 additions and 100 deletions

View File

@ -1,12 +1,13 @@
import asyncio
import inspect
from typing import Any, Dict, List, Tuple, Callable, Optional, cast
from contextlib import AsyncExitStack, contextmanager, asynccontextmanager
from pydantic.fields import Required, Undefined
from nonebot.typing import T_State, T_Handler
from nonebot.adapters import Bot, Event, Message
from nonebot.dependencies import Param, Dependent
from nonebot.typing import T_State, T_Handler, T_DependencyCache
from nonebot.consts import (
CMD_KEY,
PREFIX_KEY,
@ -19,7 +20,6 @@ from nonebot.consts import (
REGEX_MATCHED,
)
from nonebot.utils import (
CacheDict,
get_name,
run_sync,
is_gen_callable,
@ -49,7 +49,7 @@ class DependsInner:
def Depends(
dependency: Optional[T_Handler] = None,
*,
use_cache: bool = False,
use_cache: bool = True,
) -> Any:
"""
:说明:
@ -114,11 +114,11 @@ class DependParam(Param):
async def _solve(
self,
stack: Optional[AsyncExitStack] = None,
dependency_cache: Optional[CacheDict[T_Handler, Any]] = None,
dependency_cache: Optional[T_DependencyCache] = None,
**kwargs: Any,
) -> Any:
use_cache: bool = self.extra["use_cache"]
dependency_cache = CacheDict() if dependency_cache is None else dependency_cache
dependency_cache = {} if dependency_cache is None else dependency_cache
sub_dependent: Dependent = self.extra["dependent"]
sub_dependent.call = cast(Callable[..., Any], sub_dependent.call)
@ -132,26 +132,28 @@ class DependParam(Param):
)
# run dependency function
async with dependency_cache:
if use_cache and call in dependency_cache:
solved = dependency_cache[call]
elif is_gen_callable(call) or is_async_gen_callable(call):
assert isinstance(
stack, AsyncExitStack
), "Generator dependency should be called in context"
if is_gen_callable(call):
cm = run_sync_ctx_manager(contextmanager(call)(**sub_values))
else:
cm = asynccontextmanager(call)(**sub_values)
solved = await stack.enter_async_context(cm)
elif is_coroutine_callable(call):
return await call(**sub_values)
task: asyncio.Task[Any]
if use_cache and call in dependency_cache:
solved = await dependency_cache[call]
elif is_gen_callable(call) or is_async_gen_callable(call):
assert isinstance(
stack, AsyncExitStack
), "Generator dependency should be called in context"
if is_gen_callable(call):
cm = run_sync_ctx_manager(contextmanager(call)(**sub_values))
else:
return await run_sync(call)(**sub_values)
# save current dependency to cache
if call not in dependency_cache:
dependency_cache[call] = solved
cm = asynccontextmanager(call)(**sub_values)
task = asyncio.create_task(stack.enter_async_context(cm))
dependency_cache[call] = task
solved = await task
elif is_coroutine_callable(call):
task = asyncio.create_task(call(**sub_values))
dependency_cache[call] = task
solved = await task
else:
task = asyncio.create_task(run_sync(call)(**sub_values))
dependency_cache[call] = task
solved = await task
return solved
@ -243,7 +245,7 @@ def _command(state=State()) -> Message:
def Command() -> Tuple[str, ...]:
return Depends(_command)
return Depends(_command, use_cache=False)
def _raw_command(state=State()) -> Message:
@ -251,7 +253,7 @@ def _raw_command(state=State()) -> Message:
def RawCommand() -> str:
return Depends(_raw_command)
return Depends(_raw_command, use_cache=False)
def _command_arg(state=State()) -> Message:
@ -259,7 +261,7 @@ def _command_arg(state=State()) -> Message:
def CommandArg() -> Message:
return Depends(_command_arg)
return Depends(_command_arg, use_cache=False)
def _shell_command_args(state=State()) -> Any:
@ -267,7 +269,7 @@ def _shell_command_args(state=State()) -> Any:
def ShellCommandArgs():
return Depends(_shell_command_args)
return Depends(_shell_command_args, use_cache=False)
def _shell_command_argv(state=State()) -> List[str]:
@ -275,7 +277,7 @@ def _shell_command_argv(state=State()) -> List[str]:
def ShellCommandArgv() -> Any:
return Depends(_shell_command_argv)
return Depends(_shell_command_argv, use_cache=False)
def _regex_matched(state=State()) -> str:
@ -283,7 +285,7 @@ def _regex_matched(state=State()) -> str:
def RegexMatched() -> str:
return Depends(_regex_matched)
return Depends(_regex_matched, use_cache=False)
def _regex_group(state=State()):
@ -291,7 +293,7 @@ def _regex_group(state=State()):
def RegexGroup() -> Tuple[Any, ...]:
return Depends(_regex_group)
return Depends(_regex_group, use_cache=False)
def _regex_dict(state=State()):
@ -299,7 +301,7 @@ def _regex_dict(state=State()):
def RegexDict() -> Dict[str, Any]:
return Depends(_regex_dict)
return Depends(_regex_dict, use_cache=False)
class MatcherParam(Param):
@ -320,14 +322,14 @@ def Received(id: str, default: Any = None) -> Any:
def _received(matcher: "Matcher"):
return matcher.get_receive(id, default)
return Depends(_received)
return Depends(_received, use_cache=False)
def LastReceived(default: Any = None) -> Any:
def _last_received(matcher: "Matcher") -> Any:
return matcher.get_receive(None, default)
return Depends(_last_received)
return Depends(_last_received, use_cache=False)
class ExceptionParam(Param):