mirror of
https://github.com/nonebot/nonebot2.git
synced 2025-07-16 02:50:48 +00:00
⚡ improve dependency cache
This commit is contained in:
@ -1,12 +1,13 @@
|
||||
import asyncio
|
||||
import inspect
|
||||
from typing import Any, Dict, List, Tuple, Callable, Optional, cast
|
||||
from contextlib import AsyncExitStack, contextmanager, asynccontextmanager
|
||||
|
||||
from pydantic.fields import Required, Undefined
|
||||
|
||||
from nonebot.typing import T_State, T_Handler
|
||||
from nonebot.adapters import Bot, Event, Message
|
||||
from nonebot.dependencies import Param, Dependent
|
||||
from nonebot.typing import T_State, T_Handler, T_DependencyCache
|
||||
from nonebot.consts import (
|
||||
CMD_KEY,
|
||||
PREFIX_KEY,
|
||||
@ -19,7 +20,6 @@ from nonebot.consts import (
|
||||
REGEX_MATCHED,
|
||||
)
|
||||
from nonebot.utils import (
|
||||
CacheDict,
|
||||
get_name,
|
||||
run_sync,
|
||||
is_gen_callable,
|
||||
@ -49,7 +49,7 @@ class DependsInner:
|
||||
def Depends(
|
||||
dependency: Optional[T_Handler] = None,
|
||||
*,
|
||||
use_cache: bool = False,
|
||||
use_cache: bool = True,
|
||||
) -> Any:
|
||||
"""
|
||||
:说明:
|
||||
@ -114,11 +114,11 @@ class DependParam(Param):
|
||||
async def _solve(
|
||||
self,
|
||||
stack: Optional[AsyncExitStack] = None,
|
||||
dependency_cache: Optional[CacheDict[T_Handler, Any]] = None,
|
||||
dependency_cache: Optional[T_DependencyCache] = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
use_cache: bool = self.extra["use_cache"]
|
||||
dependency_cache = CacheDict() if dependency_cache is None else dependency_cache
|
||||
dependency_cache = {} if dependency_cache is None else dependency_cache
|
||||
|
||||
sub_dependent: Dependent = self.extra["dependent"]
|
||||
sub_dependent.call = cast(Callable[..., Any], sub_dependent.call)
|
||||
@ -132,26 +132,28 @@ class DependParam(Param):
|
||||
)
|
||||
|
||||
# run dependency function
|
||||
async with dependency_cache:
|
||||
if use_cache and call in dependency_cache:
|
||||
solved = dependency_cache[call]
|
||||
elif is_gen_callable(call) or is_async_gen_callable(call):
|
||||
assert isinstance(
|
||||
stack, AsyncExitStack
|
||||
), "Generator dependency should be called in context"
|
||||
if is_gen_callable(call):
|
||||
cm = run_sync_ctx_manager(contextmanager(call)(**sub_values))
|
||||
else:
|
||||
cm = asynccontextmanager(call)(**sub_values)
|
||||
solved = await stack.enter_async_context(cm)
|
||||
elif is_coroutine_callable(call):
|
||||
return await call(**sub_values)
|
||||
task: asyncio.Task[Any]
|
||||
if use_cache and call in dependency_cache:
|
||||
solved = await dependency_cache[call]
|
||||
elif is_gen_callable(call) or is_async_gen_callable(call):
|
||||
assert isinstance(
|
||||
stack, AsyncExitStack
|
||||
), "Generator dependency should be called in context"
|
||||
if is_gen_callable(call):
|
||||
cm = run_sync_ctx_manager(contextmanager(call)(**sub_values))
|
||||
else:
|
||||
return await run_sync(call)(**sub_values)
|
||||
|
||||
# save current dependency to cache
|
||||
if call not in dependency_cache:
|
||||
dependency_cache[call] = solved
|
||||
cm = asynccontextmanager(call)(**sub_values)
|
||||
task = asyncio.create_task(stack.enter_async_context(cm))
|
||||
dependency_cache[call] = task
|
||||
solved = await task
|
||||
elif is_coroutine_callable(call):
|
||||
task = asyncio.create_task(call(**sub_values))
|
||||
dependency_cache[call] = task
|
||||
solved = await task
|
||||
else:
|
||||
task = asyncio.create_task(run_sync(call)(**sub_values))
|
||||
dependency_cache[call] = task
|
||||
solved = await task
|
||||
|
||||
return solved
|
||||
|
||||
@ -243,7 +245,7 @@ def _command(state=State()) -> Message:
|
||||
|
||||
|
||||
def Command() -> Tuple[str, ...]:
|
||||
return Depends(_command)
|
||||
return Depends(_command, use_cache=False)
|
||||
|
||||
|
||||
def _raw_command(state=State()) -> Message:
|
||||
@ -251,7 +253,7 @@ def _raw_command(state=State()) -> Message:
|
||||
|
||||
|
||||
def RawCommand() -> str:
|
||||
return Depends(_raw_command)
|
||||
return Depends(_raw_command, use_cache=False)
|
||||
|
||||
|
||||
def _command_arg(state=State()) -> Message:
|
||||
@ -259,7 +261,7 @@ def _command_arg(state=State()) -> Message:
|
||||
|
||||
|
||||
def CommandArg() -> Message:
|
||||
return Depends(_command_arg)
|
||||
return Depends(_command_arg, use_cache=False)
|
||||
|
||||
|
||||
def _shell_command_args(state=State()) -> Any:
|
||||
@ -267,7 +269,7 @@ def _shell_command_args(state=State()) -> Any:
|
||||
|
||||
|
||||
def ShellCommandArgs():
|
||||
return Depends(_shell_command_args)
|
||||
return Depends(_shell_command_args, use_cache=False)
|
||||
|
||||
|
||||
def _shell_command_argv(state=State()) -> List[str]:
|
||||
@ -275,7 +277,7 @@ def _shell_command_argv(state=State()) -> List[str]:
|
||||
|
||||
|
||||
def ShellCommandArgv() -> Any:
|
||||
return Depends(_shell_command_argv)
|
||||
return Depends(_shell_command_argv, use_cache=False)
|
||||
|
||||
|
||||
def _regex_matched(state=State()) -> str:
|
||||
@ -283,7 +285,7 @@ def _regex_matched(state=State()) -> str:
|
||||
|
||||
|
||||
def RegexMatched() -> str:
|
||||
return Depends(_regex_matched)
|
||||
return Depends(_regex_matched, use_cache=False)
|
||||
|
||||
|
||||
def _regex_group(state=State()):
|
||||
@ -291,7 +293,7 @@ def _regex_group(state=State()):
|
||||
|
||||
|
||||
def RegexGroup() -> Tuple[Any, ...]:
|
||||
return Depends(_regex_group)
|
||||
return Depends(_regex_group, use_cache=False)
|
||||
|
||||
|
||||
def _regex_dict(state=State()):
|
||||
@ -299,7 +301,7 @@ def _regex_dict(state=State()):
|
||||
|
||||
|
||||
def RegexDict() -> Dict[str, Any]:
|
||||
return Depends(_regex_dict)
|
||||
return Depends(_regex_dict, use_cache=False)
|
||||
|
||||
|
||||
class MatcherParam(Param):
|
||||
@ -320,14 +322,14 @@ def Received(id: str, default: Any = None) -> Any:
|
||||
def _received(matcher: "Matcher"):
|
||||
return matcher.get_receive(id, default)
|
||||
|
||||
return Depends(_received)
|
||||
return Depends(_received, use_cache=False)
|
||||
|
||||
|
||||
def LastReceived(default: Any = None) -> Any:
|
||||
def _last_received(matcher: "Matcher") -> Any:
|
||||
return matcher.get_receive(None, default)
|
||||
|
||||
return Depends(_last_received)
|
||||
return Depends(_last_received, use_cache=False)
|
||||
|
||||
|
||||
class ExceptionParam(Param):
|
||||
|
Reference in New Issue
Block a user