♻️ use class rule and permission

This commit is contained in:
yanyongyu
2021-12-06 10:10:51 +08:00
parent ca4d7397f8
commit 5b75b72720
8 changed files with 202 additions and 135 deletions

View File

@ -69,22 +69,22 @@ def get_sub_dependant(
allow_types: Optional[List[Type[Param]]] = None,
) -> Dependent:
sub_dependant = get_dependent(
func=dependency, name=name, use_cache=depends.use_cache, allow_types=allow_types
call=dependency, name=name, use_cache=depends.use_cache, allow_types=allow_types
)
return sub_dependant
def get_dependent(
*,
func: T_Handler,
call: T_Handler,
name: Optional[str] = None,
use_cache: bool = True,
allow_types: Optional[List[Type[Param]]] = None,
) -> Dependent:
signature = get_typed_signature(func)
signature = get_typed_signature(call)
params = signature.parameters
dependent = Dependent(
func=func, name=name, allow_types=allow_types, use_cache=use_cache
call=call, name=name, allow_types=allow_types, use_cache=use_cache
)
for param_name, param in params.items():
if isinstance(param.default, DependsWrapper):
@ -108,7 +108,7 @@ def get_dependent(
break
else:
raise ValueError(
f"Unknown parameter {param_name} for function {func} with type {param.annotation}"
f"Unknown parameter {param_name} for function {call} with type {param.annotation}"
)
annotation: Any = Any
@ -153,7 +153,7 @@ async def solve_dependencies(
if errs_:
logger.debug(
f"{field_info} "
f"type {type(value)} not match depends {_dependent.func} "
f"type {type(value)} not match depends {_dependent.call} "
f"annotation {field._type_display()}, ignored"
)
raise SkippedException(field, value)
@ -163,9 +163,9 @@ async def solve_dependencies(
# solve sub dependencies
sub_dependent: Dependent
for sub_dependent in chain(_sub_dependents or tuple(), _dependent.dependencies):
sub_dependent.func = cast(Callable[..., Any], sub_dependent.func)
sub_dependent.call = cast(Callable[..., Any], sub_dependent.call)
sub_dependent.cache_key = cast(Callable[..., Any], sub_dependent.cache_key)
func = sub_dependent.func
call = sub_dependent.call
# solve sub dependency with current cache
solved_result = await solve_dependencies(
@ -179,19 +179,19 @@ async def solve_dependencies(
async with cache_lock:
if sub_dependent.use_cache and sub_dependent.cache_key in dependency_cache:
solved = dependency_cache[sub_dependent.cache_key]
elif is_gen_callable(func) or is_async_gen_callable(func):
elif is_gen_callable(call) or is_async_gen_callable(call):
assert isinstance(
_stack, AsyncExitStack
), "Generator dependency should be called in context"
if is_gen_callable(func):
cm = run_sync_ctx_manager(contextmanager(func)(**sub_values))
if is_gen_callable(call):
cm = run_sync_ctx_manager(contextmanager(call)(**sub_values))
else:
cm = asynccontextmanager(func)(**sub_values)
cm = asynccontextmanager(call)(**sub_values)
solved = await _stack.enter_async_context(cm)
elif is_coroutine_callable(func):
solved = await func(**sub_values)
elif is_coroutine_callable(call):
solved = await call(**sub_values)
else:
solved = await run_sync(func)(**sub_values)
solved = await run_sync(call)(**sub_values)
# parameter dependency
if sub_dependent.name is not None:

View File

@ -36,17 +36,17 @@ class Dependent:
def __init__(
self,
*,
func: Optional[T_Handler] = None,
call: Optional[T_Handler] = None,
name: Optional[str] = None,
params: Optional[List[ModelField]] = None,
allow_types: Optional[List[Type[Param]]] = None,
dependencies: Optional[List["Dependent"]] = None,
use_cache: bool = True,
) -> None:
self.func = func
self.call = call
self.name = name
self.params = params or []
self.allow_types = allow_types or []
self.dependencies = dependencies or []
self.use_cache = use_cache
self.cache_key = self.func
self.cache_key = self.call

View File

@ -7,9 +7,9 @@ from pydantic.typing import ForwardRef, evaluate_forwardref
from nonebot.typing import T_Handler
def get_typed_signature(func: T_Handler) -> inspect.Signature:
signature = inspect.signature(func)
globalns = getattr(func, "__globals__", {})
def get_typed_signature(call: T_Handler) -> inspect.Signature:
signature = inspect.signature(call)
globalns = getattr(call, "__globals__", {})
typed_params = [
inspect.Parameter(
name=param.name,