🐛 fix cache concurrency

This commit is contained in:
yanyongyu
2021-11-21 15:46:48 +08:00
parent d22630e768
commit 75d4cd9565
8 changed files with 162 additions and 73 deletions

View File

@@ -21,8 +21,11 @@ from .models import Dependent as Dependent
from nonebot.exception import SkippedException
from .models import DependsWrapper as DependsWrapper
from nonebot.typing import T_Handler, T_DependencyCache
from nonebot.utils import (run_sync, is_gen_callable, run_sync_ctx_manager,
is_async_gen_callable, is_coroutine_callable)
from nonebot.utils import (CacheLock, run_sync, is_gen_callable,
run_sync_ctx_manager, is_async_gen_callable,
is_coroutine_callable)
cache_lock = CacheLock()
class CustomConfig(BaseConfig):
@@ -93,7 +96,7 @@ def get_dependent(*,
break
else:
raise ValueError(
f"Unknown parameter {param_name} for funcction {func} with type {param.annotation}"
f"Unknown parameter {param_name} for function {func} with type {param.annotation}"
)
annotation: Any = Any
@@ -122,7 +125,7 @@ async def solve_dependencies(
_dependency_cache: Optional[T_DependencyCache] = None,
**params: Any) -> Tuple[Dict[str, Any], T_DependencyCache]:
values: Dict[str, Any] = {}
dependency_cache = _dependency_cache or {}
dependency_cache = {} if _dependency_cache is None else _dependency_cache
# solve sub dependencies
sub_dependent: Dependent
@@ -151,13 +154,14 @@ async def solve_dependencies(
solved_result = await solve_dependencies(
_dependent=use_sub_dependant,
_dependency_overrides_provider=_dependency_overrides_provider,
dependency_cache=dependency_cache,
_dependency_cache=dependency_cache,
**params)
sub_values, sub_dependency_cache = solved_result
# update cache?
dependency_cache.update(sub_dependency_cache)
# dependency_cache.update(sub_dependency_cache)
# run dependency function
async with cache_lock:
if sub_dependent.use_cache and sub_dependent.cache_key in dependency_cache:
solved = dependency_cache[sub_dependent.cache_key]
elif is_gen_callable(func) or is_async_gen_callable(func):
@@ -165,7 +169,8 @@ async def solve_dependencies(
_stack, AsyncExitStack
), "Generator dependency should be called in context"
if is_gen_callable(func):
cm = run_sync_ctx_manager(contextmanager(func)(**sub_values))
cm = run_sync_ctx_manager(
contextmanager(func)(**sub_values))
else:
cm = asynccontextmanager(func)(**sub_values)
solved = await _stack.enter_async_context(cm)

View File

@@ -80,7 +80,7 @@ class Handler:
_dependency_cache: Optional[Dict[Callable[..., Any],
Any]] = None,
**params) -> Any:
values, cache = await solve_dependencies(
values, _ = await solve_dependencies(
_dependent=self.dependent,
_stack=_stack,
_sub_dependents=[

View File

@@ -163,7 +163,7 @@ async def _run_matcher(
try:
logger.debug(f"Running matcher {matcher}")
await matcher.run(bot, event, state)
await matcher.run(bot, event, state, stack, dependency_cache)
except Exception as e:
logger.opt(colors=True, exception=e).error(
f"<r><bg #f8bbd0>Running matcher {matcher} failed.</bg #f8bbd0></r>"
@@ -260,7 +260,8 @@ async def handle_event(bot: "Bot", event: "Event") -> None:
logger.debug(f"Checking for matchers in priority {priority}...")
pending_tasks = [
_check_matcher(priority, matcher, bot, event, state.copy())
_check_matcher(priority, matcher, bot, event, state.copy(),
stack, dependency_cache)
for matcher in matchers[priority]
]

View File

@@ -42,15 +42,17 @@ class Permission:
]
def __init__(self,
*checkers: T_PermissionChecker,
*checkers: Union[T_PermissionChecker, Handler],
dependency_overrides_provider: Optional[Any] = None) -> None:
"""
:参数:
* ``*checkers: T_PermissionChecker``: PermissionChecker
* ``*checkers: Union[T_PermissionChecker, Handler]``: PermissionChecker
"""
self.checkers = set(
Handler(checker,
checker if isinstance(checker, Handler) else Handler(
checker,
allow_types=self.HANDLER_PARAM_TYPES,
dependency_overrides_provider=dependency_overrides_provider)
for checker in checkers)
@@ -90,11 +92,11 @@ class Permission:
if not self.checkers:
return True
results = await asyncio.gather(
checker(bot=bot,
*(checker(bot=bot,
event=event,
_stack=stack,
_dependency_cache=dependency_cache)
for checker in self.checkers)
for checker in self.checkers))
return any(results)
def __and__(self, other) -> NoReturn:
@@ -111,19 +113,19 @@ class Permission:
return Permission(*self.checkers, other)
async def _message(bot: Bot, event: Event) -> bool:
async def _message(event: Event) -> bool:
return event.get_type() == "message"
async def _notice(bot: Bot, event: Event) -> bool:
async def _notice(event: Event) -> bool:
return event.get_type() == "notice"
async def _request(bot: Bot, event: Event) -> bool:
async def _request(event: Event) -> bool:
return event.get_type() == "request"
async def _metaevent(bot: Bot, event: Event) -> bool:
async def _metaevent(event: Event) -> bool:
return event.get_type() == "meta_event"

View File

@@ -69,16 +69,17 @@ class Rule:
]
def __init__(self,
*checkers: T_RuleChecker,
*checkers: Union[T_RuleChecker, Handler],
dependency_overrides_provider: Optional[Any] = None) -> None:
"""
:参数:
* ``*checkers: T_RuleChecker``: RuleChecker
* ``*checkers: Union[T_RuleChecker, Handler]``: RuleChecker
"""
self.checkers = set(
Handler(checker,
checker if isinstance(checker, Handler) else Handler(
checker,
allow_types=self.HANDLER_PARAM_TYPES,
dependency_overrides_provider=dependency_overrides_provider)
for checker in checkers)
@@ -120,12 +121,12 @@ class Rule:
if not self.checkers:
return True
results = await asyncio.gather(
checker(bot=bot,
*(checker(bot=bot,
event=event,
state=state,
_stack=stack,
_dependency_cache=dependency_cache)
for checker in self.checkers)
for checker in self.checkers))
return all(results)
def __and__(self, other: Optional[Union["Rule", T_RuleChecker]]) -> "Rule":

View File

@@ -2,12 +2,13 @@ import re
import json
import asyncio
import inspect
import collections
import dataclasses
from functools import wraps, partial
from contextlib import asynccontextmanager
from typing_extensions import GenericAlias # type: ignore
from typing_extensions import ParamSpec, get_args, get_origin
from typing import (Any, Type, Tuple, Union, TypeVar, Callable, Optional,
from typing import (Any, Type, Deque, Tuple, Union, TypeVar, Callable, Optional,
Awaitable, AsyncGenerator, ContextManager)
from nonebot.log import logger
@@ -120,6 +121,79 @@ def get_name(obj: Any) -> str:
return obj.__class__.__name__
class CacheLock:
def __init__(self):
self._waiters: Optional[Deque[asyncio.Future]] = None
self._locked = False
def __repr__(self):
extra = "locked" if self._locked else "unlocked"
if self._waiters:
extra = f"{extra}, waiters: {len(self._waiters)}"
return f"<{self.__class__.__name__} [{extra}]>"
async def __aenter__(self):
await self.acquire()
return None
async def __aexit__(self, exc_type, exc, tb):
self.release()
def locked(self):
return self._locked
async def acquire(self):
if (not self._locked and (self._waiters is None or
all(w.cancelled() for w in self._waiters))):
self._locked = True
return True
if self._waiters is None:
self._waiters = collections.deque()
loop = asyncio.get_running_loop()
future = loop.create_future()
self._waiters.append(future)
# Finally block should be called before the CancelledError
# handling as we don't want CancelledError to call
# _wake_up_first() and attempt to wake up itself.
try:
try:
await future
finally:
self._waiters.remove(future)
except asyncio.CancelledError:
if not self._locked:
self._wake_up_first()
raise
self._locked = True
return True
def release(self):
if self._locked:
self._locked = False
self._wake_up_first()
else:
raise RuntimeError("Lock is not acquired.")
def _wake_up_first(self):
if not self._waiters:
return
try:
future = next(iter(self._waiters))
except StopIteration:
return
# .done() necessarily means that a waiter will wake up later on and
# either take the lock, or, if it was cancelled and lock wasn't
# taken already, will hit this again and wake up a new waiter.
if not future.done():
future.set_result(True)
class DataclassEncoder(json.JSONEncoder):
"""
:说明:

View File

@@ -1,26 +1,21 @@
from typing import TYPE_CHECKING
from nonebot.adapters import Event
from nonebot.permission import Permission
from .event import PrivateMessageEvent, GroupMessageEvent
if TYPE_CHECKING:
from nonebot.adapters import Bot, Event
from .event import GroupMessageEvent, PrivateMessageEvent
async def _private(bot: "Bot", event: "Event") -> bool:
async def _private(event: Event) -> bool:
return isinstance(event, PrivateMessageEvent)
async def _private_friend(bot: "Bot", event: "Event") -> bool:
async def _private_friend(event: Event) -> bool:
return isinstance(event, PrivateMessageEvent) and event.sub_type == "friend"
async def _private_group(bot: "Bot", event: "Event") -> bool:
async def _private_group(event: Event) -> bool:
return isinstance(event, PrivateMessageEvent) and event.sub_type == "group"
async def _private_other(bot: "Bot", event: "Event") -> bool:
async def _private_other(event: Event) -> bool:
return isinstance(event, PrivateMessageEvent) and event.sub_type == "other"
@@ -42,20 +37,20 @@ PRIVATE_OTHER = Permission(_private_other)
"""
async def _group(bot: "Bot", event: "Event") -> bool:
async def _group(event: Event) -> bool:
return isinstance(event, GroupMessageEvent)
async def _group_member(bot: "Bot", event: "Event") -> bool:
async def _group_member(event: Event) -> bool:
return isinstance(event,
GroupMessageEvent) and event.sender.role == "member"
async def _group_admin(bot: "Bot", event: "Event") -> bool:
async def _group_admin(event: Event) -> bool:
return isinstance(event, GroupMessageEvent) and event.sender.role == "admin"
async def _group_owner(bot: "Bot", event: "Event") -> bool:
async def _group_owner(event: Event) -> bool:
return isinstance(event, GroupMessageEvent) and event.sender.role == "owner"

View File

@@ -1,11 +1,12 @@
from nonebot import on_command
from nonebot.log import logger
from nonebot.dependencies import Depends
from nonebot import on_command, on_message
test = on_command("123")
def depend(state: dict):
print("==== depends running =====")
return state
@@ -13,5 +14,15 @@ def depend(state: dict):
@test.got("b", prompt="b")
@test.receive()
@test.got("c", prompt="c")
async def _(state: dict = Depends(depend)):
logger.info(f"=======, {state}")
async def _(x: dict = Depends(depend)):
logger.info(f"=======, {x}")
test_cache1 = on_message()
test_cache2 = on_message()
@test_cache1.handle()
@test_cache2.handle()
async def _(x: dict = Depends(depend)):
logger.info(f"======= test, {x}")