🎨 format code using black and isort

This commit is contained in:
yanyongyu
2021-11-22 23:21:26 +08:00
parent 602185a34e
commit a98d98cd12
86 changed files with 2893 additions and 2095 deletions

View File

@ -40,8 +40,7 @@ from nonebot.log import logger, default_filter
from nonebot.drivers import Driver, ReverseDriver
try:
_dist: pkg_resources.Distribution = pkg_resources.get_distribution(
"nonebot2")
_dist: pkg_resources.Distribution = pkg_resources.get_distribution("nonebot2")
__version__ = _dist.version
VERSION = _dist.parsed_version
except pkg_resources.DistributionNotFound:
@ -100,8 +99,8 @@ def get_app() -> Any:
"""
driver = get_driver()
assert isinstance(
driver,
ReverseDriver), "app object is only available for reverse driver"
driver, ReverseDriver
), "app object is only available for reverse driver"
return driver.server_app
@ -128,8 +127,8 @@ def get_asgi() -> Any:
"""
driver = get_driver()
assert isinstance(
driver,
ReverseDriver), "asgi object is only available for reverse driver"
driver, ReverseDriver
), "asgi object is only available for reverse driver"
return driver.asgi
@ -226,17 +225,23 @@ def init(*, _env_file: Optional[str] = None, **kwargs):
if not _driver:
logger.success("NoneBot is initializing...")
env = Env()
config = Config(**kwargs,
_common_config=env.dict(),
_env_file=_env_file or f".env.{env.environment}")
config = Config(
**kwargs,
_common_config=env.dict(),
_env_file=_env_file or f".env.{env.environment}",
)
default_filter.level = (
"DEBUG" if config.debug else
"INFO") if config.log_level is None else config.log_level
("DEBUG" if config.debug else "INFO")
if config.log_level is None
else config.log_level
)
logger.opt(colors=True).info(
f"Current <y><b>Env: {escape_tag(env.environment)}</b></y>")
f"Current <y><b>Env: {escape_tag(env.environment)}</b></y>"
)
logger.opt(colors=True).debug(
f"Loaded <y><b>Config</b></y>: {escape_tag(str(config.dict()))}")
f"Loaded <y><b>Config</b></y>: {escape_tag(str(config.dict()))}"
)
modulename, _, cls = config.driver.partition(":")
module = importlib.import_module(modulename)
@ -247,10 +252,7 @@ def init(*, _env_file: Optional[str] = None, **kwargs):
_driver = DriverClass(env, config)
def run(host: Optional[str] = None,
port: Optional[int] = None,
*args,
**kwargs):
def run(host: Optional[str] = None, port: Optional[int] = None, *args, **kwargs):
"""
:说明:

View File

@ -9,13 +9,13 @@ from typing import Iterable
try:
import pkg_resources
pkg_resources.declare_namespace(__name__)
del pkg_resources
except ImportError:
import pkgutil
__path__: Iterable[str] = pkgutil.extend_path(
__path__, # type: ignore
__name__)
__path__: Iterable[str] = pkgutil.extend_path(__path__, __name__) # type: ignore
del pkgutil
except Exception:
pass

View File

@ -15,7 +15,6 @@ if TYPE_CHECKING:
class _ApiCall(Protocol):
async def __call__(self, **kwargs: Any) -> Any:
...
@ -146,7 +145,8 @@ class Bot(abc.ABC):
except Exception as e:
logger.opt(colors=True, exception=e).error(
"<r><bg #f8bbd0>Error when running CallingAPI hook. "
"Running cancelled!</bg #f8bbd0></r>")
"Running cancelled!</bg #f8bbd0></r>"
)
exception = None
result = None
@ -157,8 +157,8 @@ class Bot(abc.ABC):
exception = e
coros = list(
map(lambda x: x(self, exception, api, data, result),
self._called_api_hook))
map(lambda x: x(self, exception, api, data, result), self._called_api_hook)
)
if coros:
try:
logger.debug("Running CalledAPI hooks...")
@ -166,16 +166,17 @@ class Bot(abc.ABC):
except Exception as e:
logger.opt(colors=True, exception=e).error(
"<r><bg #f8bbd0>Error when running CalledAPI hook. "
"Running cancelled!</bg #f8bbd0></r>")
"Running cancelled!</bg #f8bbd0></r>"
)
if exception:
raise exception
return result
@abc.abstractmethod
async def send(self, event: "Event", message: Union[str, "Message",
"MessageSegment"],
**kwargs) -> Any:
async def send(
self, event: "Event", message: Union[str, "Message", "MessageSegment"], **kwargs
) -> Any:
"""
:说明:

View File

@ -2,9 +2,8 @@ import abc
from pydantic import BaseModel
from nonebot.utils import DataclassEncoder
from ._message import Message
from nonebot.utils import DataclassEncoder
class Event(abc.ABC, BaseModel):

View File

@ -1,8 +1,17 @@
import abc
from copy import deepcopy
from dataclasses import field, asdict, dataclass
from typing import (Any, Dict, List, Type, Union, Generic, Mapping, TypeVar,
Iterable)
from typing import (
Any,
Dict,
List,
Type,
Union,
Generic,
Mapping,
TypeVar,
Iterable,
)
from ._template import MessageTemplate
@ -14,6 +23,7 @@ TM = TypeVar("TM", bound="Message")
@dataclass
class MessageSegment(Mapping, abc.ABC, Generic[TM]):
"""消息段基类"""
type: str
"""
- 类型: ``str``
@ -82,11 +92,12 @@ class MessageSegment(Mapping, abc.ABC, Generic[TM]):
class Message(List[TMS], abc.ABC):
"""消息数组"""
def __init__(self: TM,
message: Union[str, None, Mapping, Iterable[Mapping], TMS, TM,
Any] = None,
*args,
**kwargs):
def __init__(
self: TM,
message: Union[str, None, Mapping, Iterable[Mapping], TMS, TM, Any] = None,
*args,
**kwargs,
):
"""
:参数:
@ -103,8 +114,7 @@ class Message(List[TMS], abc.ABC):
self.extend(self._construct(message))
@classmethod
def template(cls: Type[TM],
format_string: Union[str, TM]) -> MessageTemplate[TM]:
def template(cls: Type[TM], format_string: Union[str, TM]) -> MessageTemplate[TM]:
"""
:说明:
@ -156,8 +166,7 @@ class Message(List[TMS], abc.ABC):
@staticmethod
@abc.abstractmethod
def _construct(
msg: Union[str, Mapping, Iterable[Mapping], Any]) -> Iterable[TMS]:
def _construct(msg: Union[str, Mapping, Iterable[Mapping], Any]) -> Iterable[TMS]:
raise NotImplementedError
def __add__(self: TM, other: Union[str, Mapping, Iterable[Mapping]]) -> TM:

View File

@ -1,8 +1,21 @@
import inspect
import functools
from string import Formatter
from typing import (TYPE_CHECKING, Any, Set, List, Type, Tuple, Union, Generic,
Mapping, TypeVar, Sequence, cast, overload)
from typing import (
TYPE_CHECKING,
Any,
Set,
List,
Type,
Tuple,
Union,
Generic,
Mapping,
TypeVar,
Sequence,
cast,
overload,
)
if TYPE_CHECKING:
from . import Message, MessageSegment
@ -15,14 +28,15 @@ class MessageTemplate(Formatter, Generic[TF]):
"""消息模板格式化实现类"""
@overload
def __init__(self: "MessageTemplate[str]",
template: str,
factory: Type[str] = str) -> None:
def __init__(
self: "MessageTemplate[str]", template: str, factory: Type[str] = str
) -> None:
...
@overload
def __init__(self: "MessageTemplate[TM]", template: Union[str, TM],
factory: Type[TM]) -> None:
def __init__(
self: "MessageTemplate[TM]", template: Union[str, TM], factory: Type[TM]
) -> None:
...
def __init__(self, template, factory=str) -> None:
@ -51,15 +65,15 @@ class MessageTemplate(Formatter, Generic[TF]):
elif isinstance(self.template, self.factory):
template = cast("Message[MessageSegment]", self.template)
for seg in template:
msg += self.vformat(str(seg), args,
kwargs) if seg.is_text() else seg
msg += self.vformat(str(seg), args, kwargs) if seg.is_text() else seg
else:
raise TypeError('template must be a string or instance of Message!')
raise TypeError("template must be a string or instance of Message!")
return msg
def vformat(self, format_string: str, args: Sequence[Any],
kwargs: Mapping[str, Any]) -> TF:
def vformat(
self, format_string: str, args: Sequence[Any], kwargs: Mapping[str, Any]
) -> TF:
used_args = set()
result, _ = self._vformat(format_string, args, kwargs, used_args, 2)
self.check_unused_args(list(used_args), args, kwargs)
@ -79,8 +93,9 @@ class MessageTemplate(Formatter, Generic[TF]):
results: List[Any] = []
for (literal_text, field_name, format_spec,
conversion) in self.parse(format_string):
for (literal_text, field_name, format_spec, conversion) in self.parse(
format_string
):
# output the literal text
if literal_text:
@ -96,14 +111,16 @@ class MessageTemplate(Formatter, Generic[TF]):
if auto_arg_index is False:
raise ValueError(
"cannot switch from manual field specification to "
"automatic field numbering")
"automatic field numbering"
)
field_name = str(auto_arg_index)
auto_arg_index += 1
elif field_name.isdigit():
if auto_arg_index:
raise ValueError(
"cannot switch from manual field specification to "
"automatic field numbering")
"automatic field numbering"
)
# disable auto arg incrementing, if it gets
# used later on, then an exception will be raised
auto_arg_index = False
@ -132,8 +149,10 @@ class MessageTemplate(Formatter, Generic[TF]):
formatted_text = self.format_field(obj, str(format_control))
results.append(formatted_text)
return self.factory(functools.reduce(self._add, results or
[""])), auto_arg_index
return (
self.factory(functools.reduce(self._add, results or [""])),
auto_arg_index,
)
def format_field(self, value: Any, format_spec: str) -> Any:
if issubclass(self.factory, str):
@ -142,11 +161,20 @@ class MessageTemplate(Formatter, Generic[TF]):
segment_class: Type[MessageSegment] = self.factory.get_segment_class()
method = getattr(segment_class, format_spec, None)
method_type = inspect.getattr_static(segment_class, format_spec, None)
return (super().format_field(value, format_spec) if
((method is None) or
(not isinstance(method_type, (classmethod, staticmethod))
) # Only Call staticmethod or classmethod
) else method(value)) if format_spec else value
return (
(
super().format_field(value, format_spec)
if (
(method is None)
or (
not isinstance(method_type, (classmethod, staticmethod))
) # Only Call staticmethod or classmethod
)
else method(value)
)
if format_spec
else value
)
def _add(self, a: Any, b: Any) -> Any:
try:

View File

@ -20,13 +20,17 @@ from ipaddress import IPv4Address
from typing import Any, Set, Dict, Tuple, Union, Mapping, Optional
from pydantic import BaseSettings, IPvAnyAddress
from pydantic.env_settings import (SettingsError, EnvSettingsSource,
InitSettingsSource, SettingsSourceCallable,
read_env_file, env_file_sentinel)
from pydantic.env_settings import (
SettingsError,
EnvSettingsSource,
InitSettingsSource,
SettingsSourceCallable,
read_env_file,
env_file_sentinel,
)
class CustomEnvSettings(EnvSettingsSource):
def __call__(self, settings: BaseSettings) -> Dict[str, Any]:
"""
Build environment variables suitable for passing to the Model.
@ -39,15 +43,24 @@ class CustomEnvSettings(EnvSettingsSource):
env_vars = {k.lower(): v for k, v in os.environ.items()}
env_file_vars: Dict[str, Optional[str]] = {}
env_file = self.env_file if self.env_file != env_file_sentinel else settings.__config__.env_file
env_file_encoding = self.env_file_encoding if self.env_file_encoding is not None else settings.__config__.env_file_encoding
env_file = (
self.env_file
if self.env_file != env_file_sentinel
else settings.__config__.env_file
)
env_file_encoding = (
self.env_file_encoding
if self.env_file_encoding is not None
else settings.__config__.env_file_encoding
)
if env_file is not None:
env_path = Path(env_file)
if env_path.is_file():
env_file_vars = read_env_file(
env_path,
encoding=env_file_encoding,
case_sensitive=settings.__config__.case_sensitive)
case_sensitive=settings.__config__.case_sensitive,
)
env_vars = {**env_file_vars, **env_vars}
for field in settings.__fields__.values():
@ -66,14 +79,12 @@ class CustomEnvSettings(EnvSettingsSource):
try:
env_val = settings.__config__.json_loads(env_val)
except ValueError as e:
raise SettingsError(
f'error parsing JSON for "{env_name}"') from e
raise SettingsError(f'error parsing JSON for "{env_name}"') from e
d[field.alias] = env_val
if env_file_vars:
for env_name, env_val in env_file_vars.items():
if (env_val is None or
len(env_val) == 0) and env_name in env_vars:
if (env_val is None or len(env_val) == 0) and env_name in env_vars:
env_val = env_vars[env_name]
try:
if env_val:
@ -87,12 +98,10 @@ class CustomEnvSettings(EnvSettingsSource):
class BaseConfig(BaseSettings):
def __getattr__(self, name: str) -> Any:
return self.__dict__.get(name)
class Config:
@classmethod
def customise_sources(
cls,
@ -101,10 +110,14 @@ class BaseConfig(BaseSettings):
file_secret_settings: SettingsSourceCallable,
) -> Tuple[SettingsSourceCallable, ...]:
common_config = init_settings.init_kwargs.pop("_common_config", {})
return (init_settings,
CustomEnvSettings(env_settings.env_file,
env_settings.env_file_encoding),
InitSettingsSource(common_config), file_secret_settings)
return (
init_settings,
CustomEnvSettings(
env_settings.env_file, env_settings.env_file_encoding
),
InitSettingsSource(common_config),
file_secret_settings,
)
class Env(BaseConfig):
@ -135,6 +148,7 @@ class Config(BaseConfig):
除了 NoneBot 的配置项外,还可以自行添加配置项到 ``.env.{environment}`` 文件中。
这些配置将会在 json 反序列化后一起带入 ``Config`` 类中。
"""
# nonebot configs
driver: str = "nonebot.drivers.fastapi"
"""
@ -210,7 +224,7 @@ class Config(BaseConfig):
API_ROOT={"123456": "http://127.0.0.1:5700"}
"""
api_timeout: Optional[float] = 30.
api_timeout: Optional[float] = 30.0
"""
- **类型**: ``Optional[float]``
- **默认值**: ``30.``

View File

@ -21,9 +21,14 @@ from .models import Dependent as Dependent
from nonebot.exception import SkippedException
from .models import DependsWrapper as DependsWrapper
from nonebot.typing import T_Handler, T_DependencyCache
from nonebot.utils import (CacheLock, run_sync, is_gen_callable,
run_sync_ctx_manager, is_async_gen_callable,
is_coroutine_callable)
from nonebot.utils import (
CacheLock,
run_sync,
is_gen_callable,
run_sync_ctx_manager,
is_async_gen_callable,
is_coroutine_callable,
)
cache_lock = CacheLock()
@ -33,60 +38,59 @@ class CustomConfig(BaseConfig):
def get_param_sub_dependent(
*,
param: inspect.Parameter,
allow_types: Optional[List[Type[Param]]] = None) -> Dependent:
*, param: inspect.Parameter, allow_types: Optional[List[Type[Param]]] = None
) -> Dependent:
depends: DependsWrapper = param.default
if depends.dependency:
dependency = depends.dependency
else:
dependency = param.annotation
return get_sub_dependant(depends=depends,
dependency=dependency,
name=param.name,
allow_types=allow_types)
return get_sub_dependant(
depends=depends, dependency=dependency, name=param.name, allow_types=allow_types
)
def get_parameterless_sub_dependant(
*,
depends: DependsWrapper,
allow_types: Optional[List[Type[Param]]] = None) -> Dependent:
*, depends: DependsWrapper, allow_types: Optional[List[Type[Param]]] = None
) -> Dependent:
assert callable(
depends.dependency
), "A parameter-less dependency must have a callable dependency"
return get_sub_dependant(depends=depends,
dependency=depends.dependency,
allow_types=allow_types)
return get_sub_dependant(
depends=depends, dependency=depends.dependency, allow_types=allow_types
)
def get_sub_dependant(
*,
depends: DependsWrapper,
dependency: T_Handler,
name: Optional[str] = None,
allow_types: Optional[List[Type[Param]]] = None) -> Dependent:
sub_dependant = get_dependent(func=dependency,
name=name,
use_cache=depends.use_cache,
allow_types=allow_types)
*,
depends: DependsWrapper,
dependency: T_Handler,
name: Optional[str] = None,
allow_types: Optional[List[Type[Param]]] = None,
) -> Dependent:
sub_dependant = get_dependent(
func=dependency, name=name, use_cache=depends.use_cache, allow_types=allow_types
)
return sub_dependant
def get_dependent(*,
func: T_Handler,
name: Optional[str] = None,
use_cache: bool = True,
allow_types: Optional[List[Type[Param]]] = None) -> Dependent:
def get_dependent(
*,
func: T_Handler,
name: Optional[str] = None,
use_cache: bool = True,
allow_types: Optional[List[Type[Param]]] = None,
) -> Dependent:
signature = get_typed_signature(func)
params = signature.parameters
dependent = Dependent(func=func,
name=name,
allow_types=allow_types,
use_cache=use_cache)
dependent = Dependent(
func=func, name=name, allow_types=allow_types, use_cache=use_cache
)
for param_name, param in params.items():
if isinstance(param.default, DependsWrapper):
sub_dependent = get_param_sub_dependent(param=param,
allow_types=allow_types)
sub_dependent = get_param_sub_dependent(
param=param, allow_types=allow_types
)
dependent.dependencies.append(sub_dependent)
continue
@ -111,44 +115,44 @@ def get_dependent(*,
required = default_value == Required
if param.annotation != param.empty:
annotation = param.annotation
annotation = get_annotation_from_field_info(annotation, field_info,
param_name)
annotation = get_annotation_from_field_info(annotation, field_info, param_name)
dependent.params.append(
ModelField(name=param_name,
type_=annotation,
class_validators=None,
model_config=CustomConfig,
default=None if required else default_value,
required=required,
field_info=field_info))
ModelField(
name=param_name,
type_=annotation,
class_validators=None,
model_config=CustomConfig,
default=None if required else default_value,
required=required,
field_info=field_info,
)
)
return dependent
async def solve_dependencies(
*,
_dependent: Dependent,
_stack: Optional[AsyncExitStack] = None,
_sub_dependents: Optional[List[Dependent]] = None,
_dependency_cache: Optional[T_DependencyCache] = None,
**params: Any) -> Tuple[Dict[str, Any], T_DependencyCache]:
*,
_dependent: Dependent,
_stack: Optional[AsyncExitStack] = None,
_sub_dependents: Optional[List[Dependent]] = None,
_dependency_cache: Optional[T_DependencyCache] = None,
**params: Any,
) -> Tuple[Dict[str, Any], T_DependencyCache]:
values: Dict[str, Any] = {}
dependency_cache = {} if _dependency_cache is None else _dependency_cache
# solve sub dependencies
sub_dependent: Dependent
for sub_dependent in chain(_sub_dependents or tuple(),
_dependent.dependencies):
for sub_dependent in chain(_sub_dependents or tuple(), _dependent.dependencies):
sub_dependent.func = cast(Callable[..., Any], sub_dependent.func)
sub_dependent.cache_key = cast(Callable[..., Any],
sub_dependent.cache_key)
sub_dependent.cache_key = cast(Callable[..., Any], sub_dependent.cache_key)
func = sub_dependent.func
# solve sub dependency with current cache
solved_result = await solve_dependencies(
_dependent=sub_dependent,
_dependency_cache=dependency_cache,
**params)
_dependent=sub_dependent, _dependency_cache=dependency_cache, **params
)
sub_values, sub_dependency_cache = solved_result
# update cache?
# dependency_cache.update(sub_dependency_cache)
@ -162,8 +166,7 @@ async def solve_dependencies(
_stack, AsyncExitStack
), "Generator dependency should be called in context"
if is_gen_callable(func):
cm = run_sync_ctx_manager(
contextmanager(func)(**sub_values))
cm = run_sync_ctx_manager(contextmanager(func)(**sub_values))
else:
cm = asynccontextmanager(func)(**sub_values)
solved = await _stack.enter_async_context(cm)
@ -182,19 +185,17 @@ async def solve_dependencies(
# usual dependency
for field in _dependent.params:
field_info = field.field_info
assert isinstance(field_info,
Param), "Params must be subclasses of Param"
assert isinstance(field_info, Param), "Params must be subclasses of Param"
value = field_info._solve(**params)
if value == Undefined:
value = field.get_default()
_, errs_ = field.validate(value,
values,
loc=(str(field_info), field.alias))
_, errs_ = field.validate(value, values, loc=(str(field_info), field.alias))
if errs_:
logger.debug(
f"{field_info} "
f"type {type(value)} not match depends {_dependent.func} "
f"annotation {field._type_display()}, ignored")
f"annotation {field._type_display()}, ignored"
)
raise SkippedException
else:
values[field.name] = value
@ -202,9 +203,7 @@ async def solve_dependencies(
return values, dependency_cache
def Depends(dependency: Optional[T_Handler] = None,
*,
use_cache: bool = True) -> Any:
def Depends(dependency: Optional[T_Handler] = None, *, use_cache: bool = True) -> Any:
"""
:说明:

View File

@ -9,7 +9,6 @@ from nonebot.typing import T_Handler
class Param(abc.ABC, FieldInfo):
@classmethod
@abc.abstractmethod
def _check(cls, name: str, param: inspect.Parameter) -> bool:
@ -21,11 +20,9 @@ class Param(abc.ABC, FieldInfo):
class DependsWrapper:
def __init__(self,
dependency: Optional[T_Handler] = None,
*,
use_cache: bool = True) -> None:
def __init__(
self, dependency: Optional[T_Handler] = None, *, use_cache: bool = True
) -> None:
self.dependency = dependency
self.use_cache = use_cache
@ -36,15 +33,16 @@ class DependsWrapper:
class Dependent:
def __init__(self,
*,
func: Optional[T_Handler] = None,
name: Optional[str] = None,
params: Optional[List[ModelField]] = None,
allow_types: Optional[List[Type[Param]]] = None,
dependencies: Optional[List["Dependent"]] = None,
use_cache: bool = True) -> None:
def __init__(
self,
*,
func: Optional[T_Handler] = None,
name: Optional[str] = None,
params: Optional[List[ModelField]] = None,
allow_types: Optional[List[Type[Param]]] = None,
dependencies: Optional[List["Dependent"]] = None,
use_cache: bool = True,
) -> None:
self.func = func
self.name = name
self.params = params or []

View File

@ -16,14 +16,14 @@ def get_typed_signature(func: T_Handler) -> inspect.Signature:
kind=param.kind,
default=param.default,
annotation=get_typed_annotation(param, globalns),
) for param in signature.parameters.values()
)
for param in signature.parameters.values()
]
typed_signature = inspect.Signature(typed_params)
return typed_signature
def get_typed_annotation(param: inspect.Parameter, globalns: Dict[str,
Any]) -> Any:
def get_typed_annotation(param: inspect.Parameter, globalns: Dict[str, Any]) -> Any:
annotation = param.annotation
if isinstance(annotation, str):
annotation = ForwardRef(annotation)
@ -31,7 +31,7 @@ def get_typed_annotation(param: inspect.Parameter, globalns: Dict[str,
annotation = evaluate_forwardref(annotation, globalns, globalns)
except Exception as e:
logger.opt(colors=True, exception=e).warning(
f"Unknown ForwardRef[\"{param.annotation}\"] for parameter {param.name}"
f'Unknown ForwardRef["{param.annotation}"] for parameter {param.name}'
)
return inspect.Parameter.empty
return annotation

View File

@ -8,8 +8,17 @@
import abc
import asyncio
from dataclasses import field, dataclass
from typing import (TYPE_CHECKING, Any, Set, Dict, Type, Union, Callable,
Optional, Awaitable)
from typing import (
TYPE_CHECKING,
Any,
Set,
Dict,
Type,
Union,
Callable,
Optional,
Awaitable,
)
from nonebot.log import logger
from nonebot.utils import escape_tag
@ -90,12 +99,14 @@ class Driver(abc.ABC):
"""
if name in self._adapters:
logger.opt(colors=True).debug(
f'Adapter "<y>{escape_tag(name)}</y>" already exists')
f'Adapter "<y>{escape_tag(name)}</y>" already exists'
)
return
self._adapters[name] = adapter
adapter.register(self, self.config, **kwargs)
logger.opt(colors=True).debug(
f'Succeeded to load adapter "<y>{escape_tag(name)}</y>"')
f'Succeeded to load adapter "<y>{escape_tag(name)}</y>"'
)
@property
@abc.abstractmethod
@ -121,7 +132,8 @@ class Driver(abc.ABC):
* ``**kwargs``
"""
logger.opt(colors=True).debug(
f"<g>Loaded adapters: {escape_tag(', '.join(self._adapters))}</g>")
f"<g>Loaded adapters: {escape_tag(', '.join(self._adapters))}</g>"
)
@abc.abstractmethod
def on_startup(self, func: Callable) -> Callable:
@ -146,8 +158,7 @@ class Driver(abc.ABC):
self._bot_connection_hook.add(func)
return func
def on_bot_disconnect(
self, func: T_BotDisconnectionHook) -> T_BotDisconnectionHook:
def on_bot_disconnect(self, func: T_BotDisconnectionHook) -> T_BotDisconnectionHook:
"""
:说明:
@ -172,7 +183,8 @@ class Driver(abc.ABC):
except Exception as e:
logger.opt(colors=True, exception=e).error(
"<r><bg #f8bbd0>Error when running WebSocketConnection hook. "
"Running cancelled!</bg #f8bbd0></r>")
"Running cancelled!</bg #f8bbd0></r>"
)
asyncio.create_task(_run_hook(bot))
@ -189,7 +201,8 @@ class Driver(abc.ABC):
except Exception as e:
logger.opt(colors=True, exception=e).error(
"<r><bg #f8bbd0>Error when running WebSocketDisConnection hook. "
"Running cancelled!</bg #f8bbd0></r>")
"Running cancelled!</bg #f8bbd0></r>"
)
asyncio.create_task(_run_hook(bot))
@ -201,8 +214,8 @@ class ForwardDriver(Driver):
@abc.abstractmethod
def setup_http_polling(
self, setup: Union["HTTPPollingSetup",
Callable[[], Awaitable["HTTPPollingSetup"]]]
self,
setup: Union["HTTPPollingSetup", Callable[[], Awaitable["HTTPPollingSetup"]]],
) -> None:
"""
:说明:
@ -217,8 +230,7 @@ class ForwardDriver(Driver):
@abc.abstractmethod
def setup_websocket(
self, setup: Union["WebSocketSetup",
Callable[[], Awaitable["WebSocketSetup"]]]
self, setup: Union["WebSocketSetup", Callable[[], Awaitable["WebSocketSetup"]]]
) -> None:
"""
:说明:
@ -288,6 +300,7 @@ class HTTPRequest(HTTPConnection):
.. _asgi http scope:
https://asgi.readthedocs.io/en/latest/specs/www.html#http-connection-scope
"""
method: str = "GET"
"""The HTTP method name, uppercased."""
body: bytes = b""
@ -309,6 +322,7 @@ class HTTPResponse:
.. _asgi http scope:
https://asgi.readthedocs.io/en/latest/specs/www.html#http-connection-scope
"""
status: int
"""HTTP status code."""
body: Optional[bytes] = None
@ -416,5 +430,5 @@ class WebSocketSetup:
"""URL"""
headers: Dict[str, str] = field(default_factory=dict)
"""HTTP headers"""
reconnect_interval: float = 3.
reconnect_interval: float = 3.0
"""WebSocket 重连间隔"""

View File

@ -20,13 +20,16 @@ from nonebot.typing import overrides
from nonebot.utils import escape_tag
from nonebot.config import Env, Config
from nonebot.drivers import WebSocket as BaseWebSocket
from nonebot.drivers import (HTTPRequest, ForwardDriver, WebSocketSetup,
HTTPPollingSetup)
from nonebot.drivers import (
HTTPRequest,
ForwardDriver,
WebSocketSetup,
HTTPPollingSetup,
)
STARTUP_FUNC = Callable[[], Awaitable[None]]
SHUTDOWN_FUNC = Callable[[], Awaitable[None]]
HTTPPOLLING_SETUP = Union[HTTPPollingSetup,
Callable[[], Awaitable[HTTPPollingSetup]]]
HTTPPOLLING_SETUP = Union[HTTPPollingSetup, Callable[[], Awaitable[HTTPPollingSetup]]]
WEBSOCKET_SETUP = Union[WebSocketSetup, Callable[[], Awaitable[WebSocketSetup]]]
HANDLED_SIGNALS = (
signal.SIGINT, # Unix signal 2. Sent by Ctrl+C.
@ -146,7 +149,8 @@ class Driver(ForwardDriver):
except Exception as e:
logger.opt(colors=True, exception=e).error(
"<r><bg #f8bbd0>Error when running startup function. "
"Ignored!</bg #f8bbd0></r>")
"Ignored!</bg #f8bbd0></r>"
)
async def main_loop(self):
await self.should_exit.wait()
@ -160,24 +164,20 @@ class Driver(ForwardDriver):
except Exception as e:
logger.opt(colors=True, exception=e).error(
"<r><bg #f8bbd0>Error when running shutdown function. "
"Ignored!</bg #f8bbd0></r>")
"Ignored!</bg #f8bbd0></r>"
)
for task in self.connections:
if not task.done():
task.cancel()
await asyncio.sleep(0.1)
tasks = [
t for t in asyncio.all_tasks() if t is not asyncio.current_task()
]
tasks = [t for t in asyncio.all_tasks() if t is not asyncio.current_task()]
if tasks and not self.force_exit:
logger.info("Waiting for tasks to finish. (CTRL+C to force quit)")
while tasks and not self.force_exit:
await asyncio.sleep(0.1)
tasks = [
t for t in asyncio.all_tasks()
if t is not asyncio.current_task()
]
tasks = [t for t in asyncio.all_tasks() if t is not asyncio.current_task()]
for task in tasks:
task.cancel()
@ -209,9 +209,7 @@ class Driver(ForwardDriver):
self.should_exit.set()
async def _http_loop(self, setup: HTTPPOLLING_SETUP):
async def _build_request(
setup: HTTPPollingSetup) -> Optional[HTTPRequest]:
async def _build_request(setup: HTTPPollingSetup) -> Optional[HTTPRequest]:
url = URL(setup.url)
if not url.is_absolute() or not url.host:
logger.opt(colors=True).error(
@ -219,10 +217,15 @@ class Driver(ForwardDriver):
)
return
host = f"{url.host}:{url.port}" if url.port else url.host
return HTTPRequest(setup.http_version, url.scheme, url.path,
url.raw_query_string.encode("latin-1"), {
**setup.headers, "host": host
}, setup.method, setup.body)
return HTTPRequest(
setup.http_version,
url.scheme,
url.path,
url.raw_query_string.encode("latin-1"),
{**setup.headers, "host": host},
setup.method,
setup.body,
)
bot: Optional[Bot] = None
request: Optional[HTTPRequest] = None
@ -230,7 +233,8 @@ class Driver(ForwardDriver):
logger.opt(colors=True).info(
f"Start http polling for <y>{escape_tag(setup.adapter.upper())} "
f"Bot {escape_tag(setup.self_id)}</y>")
f"Bot {escape_tag(setup.self_id)}</y>"
)
try:
async with aiohttp.ClientSession() as session:
@ -244,7 +248,8 @@ class Driver(ForwardDriver):
except Exception as e:
logger.opt(colors=True, exception=e).error(
"<r><bg #f8bbd0>Error while parsing setup "
f"{escape_tag(repr(setup))}.</bg #f8bbd0></r>")
f"{escape_tag(repr(setup))}.</bg #f8bbd0></r>"
)
await asyncio.sleep(3)
continue
@ -286,19 +291,22 @@ class Driver(ForwardDriver):
)
try:
async with session.request(request.method,
setup_.url,
data=request.body,
headers=headers,
timeout=timeout,
version=version) as response:
async with session.request(
request.method,
setup_.url,
data=request.body,
headers=headers,
timeout=timeout,
version=version,
) as response:
response.raise_for_status()
data = await response.read()
asyncio.create_task(bot.handle_message(data))
except aiohttp.ClientResponseError as e:
logger.opt(colors=True, exception=e).error(
f"<r><bg #f8bbd0>Error occurred while requesting {escape_tag(setup_.url)}. "
"Try to reconnect...</bg #f8bbd0></r>")
"Try to reconnect...</bg #f8bbd0></r>"
)
await asyncio.sleep(setup_.poll_interval)
@ -307,7 +315,8 @@ class Driver(ForwardDriver):
except Exception as e:
logger.opt(colors=True, exception=e).error(
"<r><bg #f8bbd0>Unexpected exception occurred "
"while http polling</bg #f8bbd0></r>")
"while http polling</bg #f8bbd0></r>"
)
finally:
if bot:
self._bot_disconnect(bot)
@ -327,7 +336,8 @@ class Driver(ForwardDriver):
except Exception as e:
logger.opt(colors=True, exception=e).error(
"<r><bg #f8bbd0>Error while parsing setup "
f"{escape_tag(repr(setup))}.</bg #f8bbd0></r>")
f"{escape_tag(repr(setup))}.</bg #f8bbd0></r>"
)
await asyncio.sleep(3)
continue
@ -346,17 +356,21 @@ class Driver(ForwardDriver):
f"Bot {setup_.self_id} from adapter {setup_.adapter} connecting to {url}"
)
try:
async with session.ws_connect(url,
headers=headers,
timeout=30.) as ws:
async with session.ws_connect(
url, headers=headers, timeout=30.0
) as ws:
logger.opt(colors=True).info(
f"WebSocket Connection to <y>{escape_tag(setup_.adapter.upper())} "
f"Bot {escape_tag(setup_.self_id)}</y> succeeded!"
)
request = WebSocket(
"1.1", url.scheme, url.path,
url.raw_query_string.encode("latin-1"), headers,
ws)
"1.1",
url.scheme,
url.path,
url.raw_query_string.encode("latin-1"),
headers,
ws,
)
BotClass = self._adapters[setup_.adapter]
bot = BotClass(setup_.self_id, request)
@ -365,25 +379,30 @@ class Driver(ForwardDriver):
msg = await ws.receive()
if msg.type == aiohttp.WSMsgType.text:
asyncio.create_task(
bot.handle_message(msg.data.encode()))
bot.handle_message(msg.data.encode())
)
elif msg.type == aiohttp.WSMsgType.binary:
asyncio.create_task(
bot.handle_message(msg.data))
asyncio.create_task(bot.handle_message(msg.data))
elif msg.type == aiohttp.WSMsgType.error:
logger.opt(colors=True).error(
"<r><bg #f8bbd0>Error while handling websocket frame. "
"Try to reconnect...</bg #f8bbd0></r>")
"Try to reconnect...</bg #f8bbd0></r>"
)
break
else:
logger.opt(colors=True).error(
"<r><bg #f8bbd0>WebSocket connection closed by peer. "
"Try to reconnect...</bg #f8bbd0></r>")
"Try to reconnect...</bg #f8bbd0></r>"
)
break
except (aiohttp.ClientResponseError,
aiohttp.ClientConnectionError) as e:
except (
aiohttp.ClientResponseError,
aiohttp.ClientConnectionError,
) as e:
logger.opt(colors=True, exception=e).error(
f"<r><bg #f8bbd0>Error while connecting to {escape_tag(str(url))}. "
"Try to reconnect...</bg #f8bbd0></r>")
"Try to reconnect...</bg #f8bbd0></r>"
)
finally:
if bot:
self._bot_disconnect(bot)
@ -395,7 +414,8 @@ class Driver(ForwardDriver):
except Exception as e:
logger.opt(colors=True, exception=e).error(
"<r><bg #f8bbd0>Unexpected exception occurred "
"while websocket loop</bg #f8bbd0></r>")
"while websocket loop</bg #f8bbd0></r>"
)
@dataclass

View File

@ -32,11 +32,15 @@ from nonebot.typing import overrides
from nonebot.utils import escape_tag
from nonebot.config import Config as NoneBotConfig
from nonebot.drivers import WebSocket as BaseWebSocket
from nonebot.drivers import (HTTPRequest, ForwardDriver, ReverseDriver,
WebSocketSetup, HTTPPollingSetup)
from nonebot.drivers import (
HTTPRequest,
ForwardDriver,
ReverseDriver,
WebSocketSetup,
HTTPPollingSetup,
)
HTTPPOLLING_SETUP = Union[HTTPPollingSetup,
Callable[[], Awaitable[HTTPPollingSetup]]]
HTTPPOLLING_SETUP = Union[HTTPPollingSetup, Callable[[], Awaitable[HTTPPollingSetup]]]
WEBSOCKET_SETUP = Union[WebSocketSetup, Callable[[], Awaitable[WebSocketSetup]]]
@ -44,6 +48,7 @@ class Config(BaseSettings):
"""
FastAPI 驱动框架设置,详情参考 FastAPI 文档
"""
fastapi_openapi_url: Optional[str] = None
"""
:类型:
@ -226,12 +231,14 @@ class Driver(ReverseDriver, ForwardDriver):
self.websockets.append(setup)
@overrides(ReverseDriver)
def run(self,
host: Optional[str] = None,
port: Optional[int] = None,
*,
app: Optional[str] = None,
**kwargs):
def run(
self,
host: Optional[str] = None,
port: Optional[int] = None,
*,
app: Optional[str] = None,
**kwargs,
):
"""使用 ``uvicorn`` 启动 FastAPI"""
super().run(host, port, app, **kwargs)
LOGGING_CONFIG = {
@ -243,10 +250,7 @@ class Driver(ReverseDriver, ForwardDriver):
},
},
"loggers": {
"uvicorn.error": {
"handlers": ["default"],
"level": "INFO"
},
"uvicorn.error": {"handlers": ["default"], "level": "INFO"},
"uvicorn.access": {
"handlers": ["default"],
"level": "INFO",
@ -258,15 +262,16 @@ class Driver(ReverseDriver, ForwardDriver):
host=host or str(self.config.host),
port=port or self.config.port,
reload=self.fastapi_config.fastapi_reload
if self.fastapi_config.fastapi_reload is not None else
(bool(app) and self.config.debug),
if self.fastapi_config.fastapi_reload is not None
else (bool(app) and self.config.debug),
reload_dirs=self.fastapi_config.fastapi_reload_dirs,
reload_delay=self.fastapi_config.fastapi_reload_delay,
reload_includes=self.fastapi_config.fastapi_reload_includes,
reload_excludes=self.fastapi_config.fastapi_reload_excludes,
debug=self.config.debug,
log_config=LOGGING_CONFIG,
**kwargs)
**kwargs,
)
def _run_forward(self):
for setup in self.http_pollings:
@ -287,39 +292,49 @@ class Driver(ReverseDriver, ForwardDriver):
logger.warning(
f"Unknown adapter {adapter}. Please register the adapter before use."
)
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND,
detail="adapter not found")
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail="adapter not found"
)
# 创建 Bot 对象
BotClass = self._adapters[adapter]
http_request = HTTPRequest(request.scope["http_version"],
request.url.scheme, request.url.path,
request.scope["query_string"],
dict(request.headers), request.method, data)
x_self_id, response = await BotClass.check_permission(
self, http_request)
http_request = HTTPRequest(
request.scope["http_version"],
request.url.scheme,
request.url.path,
request.scope["query_string"],
dict(request.headers),
request.method,
data,
)
x_self_id, response = await BotClass.check_permission(self, http_request)
if not x_self_id:
raise HTTPException(
response and response.status or 401, response and
response.body and response.body.decode("utf-8"))
response and response.status or 401,
response and response.body and response.body.decode("utf-8"),
)
if x_self_id in self._clients:
logger.warning("There's already a reverse websocket connection,"
"so the event may be handled twice.")
logger.warning(
"There's already a reverse websocket connection,"
"so the event may be handled twice."
)
bot = BotClass(x_self_id, http_request)
asyncio.create_task(bot.handle_message(data))
return Response(response and response.body,
response and response.status or 200)
return Response(response and response.body, response and response.status or 200)
async def _handle_ws_reverse(self, adapter: str,
websocket: FastAPIWebSocket):
ws = WebSocket(websocket.scope.get("http_version",
"1.1"), websocket.url.scheme,
websocket.url.path, websocket.scope["query_string"],
dict(websocket.headers), websocket)
async def _handle_ws_reverse(self, adapter: str, websocket: FastAPIWebSocket):
ws = WebSocket(
websocket.scope.get("http_version", "1.1"),
websocket.url.scheme,
websocket.url.path,
websocket.scope["query_string"],
dict(websocket.headers),
websocket,
)
if adapter not in self._adapters:
logger.warning(
@ -349,7 +364,8 @@ class Driver(ReverseDriver, ForwardDriver):
await ws.accept()
logger.opt(colors=True).info(
f"WebSocket Connection from <y>{escape_tag(adapter.upper())} "
f"Bot {escape_tag(self_id)}</y> Accepted!")
f"Bot {escape_tag(self_id)}</y> Accepted!"
)
self._bot_connect(bot)
@ -362,7 +378,8 @@ class Driver(ReverseDriver, ForwardDriver):
break
except Exception as e:
logger.opt(exception=e).error(
"Error when receiving data from websocket.")
"Error when receiving data from websocket."
)
break
asyncio.create_task(bot.handle_message(data.encode()))
@ -370,9 +387,7 @@ class Driver(ReverseDriver, ForwardDriver):
self._bot_disconnect(bot)
async def _http_loop(self, setup: HTTPPOLLING_SETUP):
async def _build_request(
setup: HTTPPollingSetup) -> Optional[HTTPRequest]:
async def _build_request(setup: HTTPPollingSetup) -> Optional[HTTPRequest]:
url = httpx.URL(setup.url)
if not url.netloc:
logger.opt(colors=True).error(
@ -380,9 +395,14 @@ class Driver(ReverseDriver, ForwardDriver):
)
return
return HTTPRequest(
setup.http_version, url.scheme, url.path, url.query, {
**setup.headers, "host": url.netloc.decode("ascii")
}, setup.method, setup.body)
setup.http_version,
url.scheme,
url.path,
url.query,
{**setup.headers, "host": url.netloc.decode("ascii")},
setup.method,
setup.body,
)
bot: Optional[Bot] = None
request: Optional[HTTPRequest] = None
@ -390,11 +410,11 @@ class Driver(ReverseDriver, ForwardDriver):
logger.opt(colors=True).info(
f"Start http polling for <y>{escape_tag(setup.adapter.upper())} "
f"Bot {escape_tag(setup.self_id)}</y>")
f"Bot {escape_tag(setup.self_id)}</y>"
)
try:
async with httpx.AsyncClient(http2=True,
follow_redirects=True) as session:
async with httpx.AsyncClient(http2=True, follow_redirects=True) as session:
while not self.shutdown.is_set():
try:
@ -405,7 +425,8 @@ class Driver(ReverseDriver, ForwardDriver):
except Exception as e:
logger.opt(colors=True, exception=e).error(
"<r><bg #f8bbd0>Error while parsing setup "
f"{escape_tag(repr(setup))}.</bg #f8bbd0></r>")
f"{escape_tag(repr(setup))}.</bg #f8bbd0></r>"
)
await asyncio.sleep(3)
continue
@ -432,18 +453,21 @@ class Driver(ReverseDriver, ForwardDriver):
f"Bot {setup_.self_id} from adapter {setup_.adapter} request {setup_.url}"
)
try:
response = await session.request(request.method,
setup_.url,
content=request.body,
headers=headers,
timeout=30.)
response = await session.request(
request.method,
setup_.url,
content=request.body,
headers=headers,
timeout=30.0,
)
response.raise_for_status()
data = response.read()
asyncio.create_task(bot.handle_message(data))
except httpx.HTTPError as e:
logger.opt(colors=True, exception=e).error(
f"<r><bg #f8bbd0>Error occurred while requesting {escape_tag(setup_.url)}. "
"Try to reconnect...</bg #f8bbd0></r>")
"Try to reconnect...</bg #f8bbd0></r>"
)
await asyncio.sleep(setup_.poll_interval)
@ -452,7 +476,8 @@ class Driver(ReverseDriver, ForwardDriver):
except Exception as e:
logger.opt(colors=True, exception=e).error(
"<r><bg #f8bbd0>Unexpected exception occurred "
"while http polling</bg #f8bbd0></r>")
"while http polling</bg #f8bbd0></r>"
)
finally:
if bot:
self._bot_disconnect(bot)
@ -471,7 +496,8 @@ class Driver(ReverseDriver, ForwardDriver):
except Exception as e:
logger.opt(colors=True, exception=e).error(
"<r><bg #f8bbd0>Error while parsing setup "
f"{escape_tag(repr(setup))}.</bg #f8bbd0></r>")
f"{escape_tag(repr(setup))}.</bg #f8bbd0></r>"
)
await asyncio.sleep(3)
continue
@ -491,9 +517,11 @@ class Driver(ReverseDriver, ForwardDriver):
async with connection as ws:
logger.opt(colors=True).info(
f"WebSocket Connection to <y>{escape_tag(setup_.adapter.upper())} "
f"Bot {escape_tag(setup_.self_id)}</y> succeeded!")
request = WebSocket("1.1", url.scheme, url.path,
url.query, headers, ws)
f"Bot {escape_tag(setup_.self_id)}</y> succeeded!"
)
request = WebSocket(
"1.1", url.scheme, url.path, url.query, headers, ws
)
BotClass = self._adapters[setup_.adapter]
bot = BotClass(setup_.self_id, request)
@ -506,12 +534,14 @@ class Driver(ReverseDriver, ForwardDriver):
except ConnectionClosed:
logger.opt(colors=True).error(
"<r><bg #f8bbd0>WebSocket connection closed by peer. "
"Try to reconnect...</bg #f8bbd0></r>")
"Try to reconnect...</bg #f8bbd0></r>"
)
break
except Exception as e:
logger.opt(colors=True, exception=e).error(
f"<r><bg #f8bbd0>Error while connecting to {url}. "
"Try to reconnect...</bg #f8bbd0></r>")
"Try to reconnect...</bg #f8bbd0></r>"
)
finally:
if bot:
self._bot_disconnect(bot)
@ -523,21 +553,22 @@ class Driver(ReverseDriver, ForwardDriver):
except Exception as e:
logger.opt(colors=True, exception=e).error(
"<r><bg #f8bbd0>Unexpected exception occurred "
"while websocket loop</bg #f8bbd0></r>")
"while websocket loop</bg #f8bbd0></r>"
)
@dataclass
class WebSocket(BaseWebSocket):
websocket: Union[FastAPIWebSocket,
WebSocketClientProtocol] = None # type: ignore
websocket: Union[FastAPIWebSocket, WebSocketClientProtocol] = None # type: ignore
@property
@overrides(BaseWebSocket)
def closed(self) -> bool:
if isinstance(self.websocket, FastAPIWebSocket):
return (
self.websocket.client_state == WebSocketState.DISCONNECTED or
self.websocket.application_state == WebSocketState.DISCONNECTED)
self.websocket.client_state == WebSocketState.DISCONNECTED
or self.websocket.application_state == WebSocketState.DISCONNECTED
)
else:
return self.websocket.closed

View File

@ -30,8 +30,7 @@ try:
from quart import Quart, Request, Response
from quart import Websocket as QuartWebSocket
except ImportError:
raise ValueError(
'Please install Quart by using `pip install nonebot2[quart]`')
raise ValueError("Please install Quart by using `pip install nonebot2[quart]`")
_AsyncCallable = TypeVar("_AsyncCallable", bound=Callable[..., Coroutine])
@ -40,6 +39,7 @@ class Config(BaseSettings):
"""
Quart 驱动框架设置
"""
quart_reload: Optional[bool] = None
"""
:类型:
@ -111,11 +111,12 @@ class Driver(ReverseDriver):
self.quart_config = Config(**config.dict())
self._server_app = Quart(self.__class__.__qualname__)
self._server_app.add_url_rule("/<adapter>/http",
methods=["POST"],
view_func=self._handle_http)
self._server_app.add_websocket("/<adapter>/ws",
view_func=self._handle_ws_reverse)
self._server_app.add_url_rule(
"/<adapter>/http", methods=["POST"], view_func=self._handle_http
)
self._server_app.add_websocket(
"/<adapter>/ws", view_func=self._handle_ws_reverse
)
@property
@overrides(ReverseDriver)
@ -156,12 +157,14 @@ class Driver(ReverseDriver):
return self.server_app.after_serving(func) # type: ignore
@overrides(ReverseDriver)
def run(self,
host: Optional[str] = None,
port: Optional[int] = None,
*,
app: Optional[str] = None,
**kwargs):
def run(
self,
host: Optional[str] = None,
port: Optional[int] = None,
*,
app: Optional[str] = None,
**kwargs,
):
"""使用 ``uvicorn`` 启动 Quart"""
super().run(host, port, app, **kwargs)
LOGGING_CONFIG = {
@ -173,10 +176,7 @@ class Driver(ReverseDriver):
},
},
"loggers": {
"uvicorn.error": {
"handlers": ["default"],
"level": "INFO"
},
"uvicorn.error": {"handlers": ["default"], "level": "INFO"},
"uvicorn.access": {
"handlers": ["default"],
"level": "INFO",
@ -188,52 +188,69 @@ class Driver(ReverseDriver):
host=host or str(self.config.host),
port=port or self.config.port,
reload=self.quart_config.quart_reload
if self.quart_config.quart_reload is not None else
(bool(app) and self.config.debug),
if self.quart_config.quart_reload is not None
else (bool(app) and self.config.debug),
reload_dirs=self.quart_config.quart_reload_dirs,
reload_delay=self.quart_config.quart_reload_delay,
reload_includes=self.quart_config.quart_reload_includes,
reload_excludes=self.quart_config.quart_reload_excludes,
debug=self.config.debug,
log_config=LOGGING_CONFIG,
**kwargs)
**kwargs,
)
async def _handle_http(self, adapter: str):
request: Request = _request
data: bytes = await request.get_data() # type: ignore
if adapter not in self._adapters:
logger.warning(f'Unknown adapter {adapter}. '
'Please register the adapter before use.')
logger.warning(
f"Unknown adapter {adapter}. " "Please register the adapter before use."
)
raise exceptions.NotFound()
BotClass = self._adapters[adapter]
http_request = HTTPRequest(request.http_version, request.scheme,
request.path, request.query_string,
dict(request.headers), request.method, data)
http_request = HTTPRequest(
request.http_version,
request.scheme,
request.path,
request.query_string,
dict(request.headers),
request.method,
data,
)
self_id, response = await BotClass.check_permission(self, http_request)
if not self_id:
raise exceptions.Unauthorized(
description=(response and response.body or b"").decode())
description=(response and response.body or b"").decode()
)
if self_id in self._clients:
logger.warning("There's already a reverse websocket connection,"
"so the event may be handled twice.")
logger.warning(
"There's already a reverse websocket connection,"
"so the event may be handled twice."
)
bot = BotClass(self_id, http_request)
asyncio.create_task(bot.handle_message(data))
return Response(response and response.body or "",
response and response.status or 200)
return Response(
response and response.body or "", response and response.status or 200
)
async def _handle_ws_reverse(self, adapter: str):
websocket: QuartWebSocket = _websocket
ws = WebSocket(websocket.http_version, websocket.scheme,
websocket.path, websocket.query_string,
dict(websocket.headers), websocket)
ws = WebSocket(
websocket.http_version,
websocket.scheme,
websocket.path,
websocket.query_string,
dict(websocket.headers),
websocket,
)
if adapter not in self._adapters:
logger.warning(
f'Unknown adapter {adapter}. Please register the adapter before use.'
f"Unknown adapter {adapter}. Please register the adapter before use."
)
raise exceptions.NotFound()
@ -242,20 +259,22 @@ class Driver(ReverseDriver):
if not self_id:
raise exceptions.Unauthorized(
description=(response and response.body or b"").decode())
description=(response and response.body or b"").decode()
)
if self_id in self._clients:
logger.opt(colors=True).warning(
"There's already a websocket connection, "
f"<y>{escape_tag(adapter.upper())} Bot {escape_tag(self_id)}</y> ignored."
)
raise exceptions.Forbidden(description='Client already exists.')
raise exceptions.Forbidden(description="Client already exists.")
bot = BotClass(self_id, ws)
await ws.accept()
logger.opt(colors=True).info(
f"WebSocket Connection from <y>{escape_tag(adapter.upper())} "
f"Bot {escape_tag(self_id)}</y> Accepted!")
f"Bot {escape_tag(self_id)}</y> Accepted!"
)
self._bot_connect(bot)
try:
@ -267,7 +286,8 @@ class Driver(ReverseDriver):
break
except Exception as e:
logger.opt(exception=e).error(
"Error when receiving data from websocket.")
"Error when receiving data from websocket."
)
break
asyncio.create_task(bot.handle_message(data.encode()))

View File

@ -157,6 +157,7 @@ class NoLogException(AdapterException):
指示 NoneBot 对当前 ``Event`` 进行处理但不显示 Log 信息,可在 ``get_log_string`` 时抛出
"""
pass
@ -166,6 +167,7 @@ class ApiNotAvailable(AdapterException):
在 API 连接不可用时抛出。
"""
pass
@ -175,6 +177,7 @@ class NetworkError(AdapterException):
在网络出现问题时抛出,如: API 请求地址不正确, API 请求无返回或返回状态非正常等。
"""
pass
@ -184,4 +187,5 @@ class ActionFailed(AdapterException):
API 请求成功返回数据,但 API 操作失败。
"""
pass

View File

@ -10,20 +10,27 @@ from contextlib import AsyncExitStack
from typing import Any, Dict, List, Type, Callable, Optional
from nonebot.utils import get_name, run_sync
from nonebot.dependencies import (Param, Dependent, DependsWrapper,
get_dependent, solve_dependencies,
get_parameterless_sub_dependant)
from nonebot.dependencies import (
Param,
Dependent,
DependsWrapper,
get_dependent,
solve_dependencies,
get_parameterless_sub_dependant,
)
class Handler:
"""事件处理器类。支持依赖注入。"""
def __init__(self,
func: Callable[..., Any],
*,
name: Optional[str] = None,
dependencies: Optional[List[DependsWrapper]] = None,
allow_types: Optional[List[Type[Param]]] = None):
def __init__(
self,
func: Callable[..., Any],
*,
name: Optional[str] = None,
dependencies: Optional[List[DependsWrapper]] = None,
allow_types: Optional[List[Type[Param]]] = None,
):
"""
:说明:
@ -64,19 +71,18 @@ class Handler:
self.dependent = get_dependent(func=func, allow_types=self.allow_types)
def __repr__(self) -> str:
return (
f"<Handler {self.name}({', '.join(map(str, self.dependent.params))})>"
)
return f"<Handler {self.name}({', '.join(map(str, self.dependent.params))})>"
def __str__(self) -> str:
return repr(self)
async def __call__(self,
*,
_stack: Optional[AsyncExitStack] = None,
_dependency_cache: Optional[Dict[Callable[..., Any],
Any]] = None,
**params) -> Any:
async def __call__(
self,
*,
_stack: Optional[AsyncExitStack] = None,
_dependency_cache: Optional[Dict[Callable[..., Any], Any]] = None,
**params,
) -> Any:
values, _ = await solve_dependencies(
_dependent=self.dependent,
_stack=_stack,
@ -85,7 +91,8 @@ class Handler:
for dependency in self.dependencies
],
_dependency_cache=_dependency_cache,
**params)
**params,
)
if asyncio.iscoroutinefunction(self.func):
return await self.func(**values)
@ -98,7 +105,8 @@ class Handler:
if dependency.dependency in self.sub_dependents:
raise ValueError(f"{dependency} is already in dependencies")
sub_dependant = get_parameterless_sub_dependant(
depends=dependency, allow_types=self.allow_types)
depends=dependency, allow_types=self.allow_types
)
self.sub_dependents[dependency.dependency] = sub_dependant
def prepend_dependency(self, dependency: DependsWrapper):

View File

@ -48,7 +48,6 @@ logger: "Logger" = loguru.logger
class Filter:
def __init__(self) -> None:
self.level: Union[int, str] = "DEBUG"
@ -58,13 +57,13 @@ class Filter:
if module:
module_name = getattr(module, "__module_name__", module_name)
record["name"] = module_name.split(".")[0]
levelno = logger.level(self.level).no if isinstance(self.level,
str) else self.level
levelno = (
logger.level(self.level).no if isinstance(self.level, str) else self.level
)
return record["level"].no >= levelno
class LoguruHandler(logging.Handler):
def emit(self, record):
try:
level = logger.level(record.levelname).name
@ -76,8 +75,9 @@ class LoguruHandler(logging.Handler):
frame = frame.f_back
depth += 1
logger.opt(depth=depth,
exception=record.exc_info).log(level, record.getMessage())
logger.opt(depth=depth, exception=record.exc_info).log(
level, record.getMessage()
)
logger.remove()
@ -87,9 +87,12 @@ default_format = (
"[<lvl>{level}</lvl>] "
"<c><u>{name}</u></c> | "
# "<c>{function}:{line}</c>| "
"{message}")
logger_id = logger.add(sys.stdout,
colorize=True,
diagnose=False,
filter=default_filter,
format=default_format)
"{message}"
)
logger_id = logger.add(
sys.stdout,
colorize=True,
diagnose=False,
filter=default_filter,
format=default_format,
)

View File

@ -10,8 +10,17 @@ from datetime import datetime
from contextvars import ContextVar
from collections import defaultdict
from contextlib import AsyncExitStack
from typing import (TYPE_CHECKING, Any, Dict, List, Type, Union, Callable,
NoReturn, Optional)
from typing import (
TYPE_CHECKING,
Any,
Dict,
List,
Type,
Union,
Callable,
NoReturn,
Optional,
)
from nonebot import params
from nonebot.rule import Rule
@ -19,14 +28,29 @@ from nonebot.log import logger
from nonebot.handler import Handler
from nonebot.dependencies import DependsWrapper
from nonebot.permission import USER, Permission
from nonebot.adapters import (Bot, Event, Message, MessageSegment,
MessageTemplate)
from nonebot.exception import (PausedException, StopPropagation,
SkippedException, FinishedException,
RejectedException)
from nonebot.typing import (T_State, T_Handler, T_ArgsParser, T_TypeUpdater,
T_StateFactory, T_DependencyCache,
T_PermissionUpdater)
from nonebot.adapters import (
Bot,
Event,
Message,
MessageSegment,
MessageTemplate,
)
from nonebot.exception import (
PausedException,
StopPropagation,
SkippedException,
FinishedException,
RejectedException,
)
from nonebot.typing import (
T_State,
T_Handler,
T_ArgsParser,
T_TypeUpdater,
T_StateFactory,
T_DependencyCache,
T_PermissionUpdater,
)
if TYPE_CHECKING:
from nonebot.plugin import Plugin
@ -57,9 +81,11 @@ class MatcherMeta(type):
expire_time: Optional[datetime]
def __repr__(self) -> str:
return (f"<Matcher from {self.module_name or 'unknown'}, "
f"type={self.type}, priority={self.priority}, "
f"temp={self.temp}>")
return (
f"<Matcher from {self.module_name or 'unknown'}, "
f"type={self.type}, priority={self.priority}, "
f"temp={self.temp}>"
)
def __str__(self) -> str:
return repr(self)
@ -67,6 +93,7 @@ class MatcherMeta(type):
class Matcher(metaclass=MatcherMeta):
"""事件响应器类"""
plugin: Optional["Plugin"] = None
"""
:类型: ``Optional[Plugin]``
@ -157,8 +184,11 @@ class Matcher(metaclass=MatcherMeta):
"""
HANDLER_PARAM_TYPES = [
params.BotParam, params.EventParam, params.StateParam,
params.MatcherParam, params.DefaultParam
params.BotParam,
params.EventParam,
params.StateParam,
params.MatcherParam,
params.DefaultParam,
]
def __init__(self):
@ -169,7 +199,8 @@ class Matcher(metaclass=MatcherMeta):
def __repr__(self) -> str:
return (
f"<Matcher from {self.module_name or 'unknown'}, type={self.type}, "
f"priority={self.priority}, temp={self.temp}>")
f"priority={self.priority}, temp={self.temp}>"
)
def __str__(self) -> str:
return repr(self)
@ -180,8 +211,9 @@ class Matcher(metaclass=MatcherMeta):
type_: str = "",
rule: Optional[Rule] = None,
permission: Optional[Permission] = None,
handlers: Optional[Union[List[T_Handler], List[Handler],
List[Union[T_Handler, Handler]]]] = None,
handlers: Optional[
Union[List[T_Handler], List[Handler], List[Union[T_Handler, Handler]]]
] = None,
temp: bool = False,
priority: int = 1,
block: bool = False,
@ -193,7 +225,7 @@ class Matcher(metaclass=MatcherMeta):
default_state_factory: Optional[T_StateFactory] = None,
default_parser: Optional[T_ArgsParser] = None,
default_type_updater: Optional[T_TypeUpdater] = None,
default_permission_updater: Optional[T_PermissionUpdater] = None
default_permission_updater: Optional[T_PermissionUpdater] = None,
) -> Type["Matcher"]:
"""
:说明:
@ -221,46 +253,37 @@ class Matcher(metaclass=MatcherMeta):
"""
NewMatcher = type(
"Matcher", (Matcher,), {
"plugin":
plugin,
"module":
module,
"plugin_name":
plugin and plugin.name,
"module_name":
module and module.__name__,
"type":
type_,
"rule":
rule or Rule(),
"permission":
permission or Permission(),
"Matcher",
(Matcher,),
{
"plugin": plugin,
"module": module,
"plugin_name": plugin and plugin.name,
"module_name": module and module.__name__,
"type": type_,
"rule": rule or Rule(),
"permission": permission or Permission(),
"handlers": [
handler if isinstance(handler, Handler) else Handler(
handler, allow_types=cls.HANDLER_PARAM_TYPES)
handler
if isinstance(handler, Handler)
else Handler(handler, allow_types=cls.HANDLER_PARAM_TYPES)
for handler in handlers
] if handlers else [],
"temp":
temp,
"expire_time":
expire_time,
"priority":
priority,
"block":
block,
"_default_state":
default_state or {},
"_default_state_factory":
staticmethod(default_state_factory)
if default_state_factory else None,
"_default_parser":
default_parser,
"_default_type_updater":
default_type_updater,
"_default_permission_updater":
default_permission_updater
})
]
if handlers
else [],
"temp": temp,
"expire_time": expire_time,
"priority": priority,
"block": block,
"_default_state": default_state or {},
"_default_state_factory": staticmethod(default_state_factory)
if default_state_factory
else None,
"_default_parser": default_parser,
"_default_type_updater": default_type_updater,
"_default_permission_updater": default_permission_updater,
},
)
matchers[priority].append(NewMatcher)
@ -272,8 +295,8 @@ class Matcher(metaclass=MatcherMeta):
bot: Bot,
event: Event,
stack: Optional[AsyncExitStack] = None,
dependency_cache: Optional[Dict[Callable[..., Any],
Any]] = None) -> bool:
dependency_cache: Optional[Dict[Callable[..., Any], Any]] = None,
) -> bool:
"""
:说明:
@ -289,8 +312,9 @@ class Matcher(metaclass=MatcherMeta):
- ``bool``: 是否满足权限
"""
event_type = event.get_type()
return (event_type == (cls.type or event_type) and
await cls.permission(bot, event, stack, dependency_cache))
return event_type == (cls.type or event_type) and await cls.permission(
bot, event, stack, dependency_cache
)
@classmethod
async def check_rule(
@ -299,8 +323,8 @@ class Matcher(metaclass=MatcherMeta):
event: Event,
state: T_State,
stack: Optional[AsyncExitStack] = None,
dependency_cache: Optional[Dict[Callable[..., Any],
Any]] = None) -> bool:
dependency_cache: Optional[Dict[Callable[..., Any], Any]] = None,
) -> bool:
"""
:说明:
@ -317,8 +341,9 @@ class Matcher(metaclass=MatcherMeta):
- ``bool``: 是否满足匹配规则
"""
event_type = event.get_type()
return (event_type == (cls.type or event_type) and
await cls.rule(bot, event, state, stack, dependency_cache))
return event_type == (cls.type or event_type) and await cls.rule(
bot, event, state, stack, dependency_cache
)
@classmethod
def args_parser(cls, func: T_ArgsParser) -> T_ArgsParser:
@ -349,8 +374,7 @@ class Matcher(metaclass=MatcherMeta):
return func
@classmethod
def permission_updater(cls,
func: T_PermissionUpdater) -> T_PermissionUpdater:
def permission_updater(cls, func: T_PermissionUpdater) -> T_PermissionUpdater:
"""
:说明:
@ -365,12 +389,11 @@ class Matcher(metaclass=MatcherMeta):
@classmethod
def append_handler(
cls,
handler: T_Handler,
dependencies: Optional[List[DependsWrapper]] = None) -> Handler:
handler_ = Handler(handler,
dependencies=dependencies,
allow_types=cls.HANDLER_PARAM_TYPES)
cls, handler: T_Handler, dependencies: Optional[List[DependsWrapper]] = None
) -> Handler:
handler_ = Handler(
handler, dependencies=dependencies, allow_types=cls.HANDLER_PARAM_TYPES
)
cls.handlers.append(handler_)
return handler_
@ -418,8 +441,7 @@ class Matcher(metaclass=MatcherMeta):
func_handler = cls.handlers[-1]
func_handler.prepend_dependency(depend)
else:
cls.append_handler(
func, dependencies=[depend] if cls.handlers else [])
cls.append_handler(func, dependencies=[depend] if cls.handlers else [])
return func
@ -429,9 +451,8 @@ class Matcher(metaclass=MatcherMeta):
def got(
cls,
key: str,
prompt: Optional[Union[str, Message, MessageSegment,
MessageTemplate]] = None,
args_parser: Optional[T_ArgsParser] = None
prompt: Optional[Union[str, Message, MessageSegment, MessageTemplate]] = None,
args_parser: Optional[T_ArgsParser] = None,
) -> Callable[[T_Handler], T_Handler]:
"""
:说明:
@ -483,16 +504,16 @@ class Matcher(metaclass=MatcherMeta):
func_handler.prepend_dependency(parser_depend)
func_handler.prepend_dependency(get_depend)
else:
cls.append_handler(func,
dependencies=[get_depend, parser_depend])
cls.append_handler(func, dependencies=[get_depend, parser_depend])
return func
return _decorator
@classmethod
async def send(cls, message: Union[str, Message, MessageSegment,
MessageTemplate], **kwargs) -> Any:
async def send(
cls, message: Union[str, Message, MessageSegment, MessageTemplate], **kwargs
) -> Any:
"""
:说明:
@ -513,10 +534,11 @@ class Matcher(metaclass=MatcherMeta):
return await bot.send(event=event, message=_message, **kwargs)
@classmethod
async def finish(cls,
message: Optional[Union[str, Message, MessageSegment,
MessageTemplate]] = None,
**kwargs) -> NoReturn:
async def finish(
cls,
message: Optional[Union[str, Message, MessageSegment, MessageTemplate]] = None,
**kwargs,
) -> NoReturn:
"""
:说明:
@ -539,10 +561,11 @@ class Matcher(metaclass=MatcherMeta):
raise FinishedException
@classmethod
async def pause(cls,
prompt: Optional[Union[str, Message, MessageSegment,
MessageTemplate]] = None,
**kwargs) -> NoReturn:
async def pause(
cls,
prompt: Optional[Union[str, Message, MessageSegment, MessageTemplate]] = None,
**kwargs,
) -> NoReturn:
"""
:说明:
@ -565,10 +588,9 @@ class Matcher(metaclass=MatcherMeta):
raise PausedException
@classmethod
async def reject(cls,
prompt: Optional[Union[str, Message,
MessageSegment]] = None,
**kwargs) -> NoReturn:
async def reject(
cls, prompt: Optional[Union[str, Message, MessageSegment]] = None, **kwargs
) -> NoReturn:
"""
:说明:
@ -601,31 +623,38 @@ class Matcher(metaclass=MatcherMeta):
self.block = True
# 运行handlers
async def run(self,
bot: Bot,
event: Event,
state: T_State,
stack: Optional[AsyncExitStack] = None,
dependency_cache: Optional[T_DependencyCache] = None):
async def run(
self,
bot: Bot,
event: Event,
state: T_State,
stack: Optional[AsyncExitStack] = None,
dependency_cache: Optional[T_DependencyCache] = None,
):
b_t = current_bot.set(bot)
e_t = current_event.set(event)
s_t = current_state.set(self.state)
try:
# Refresh preprocess state
self.state = await self._default_state_factory(
bot, event) if self._default_state_factory else self.state
self.state = (
await self._default_state_factory(bot, event)
if self._default_state_factory
else self.state
)
self.state.update(state)
while self.handlers:
handler = self.handlers.pop(0)
logger.debug(f"Running handler {handler}")
try:
await handler(matcher=self,
bot=bot,
event=event,
state=self.state,
_stack=stack,
_dependency_cache=dependency_cache)
await handler(
matcher=self,
bot=bot,
event=event,
state=self.state,
_stack=stack,
_dependency_cache=dependency_cache,
)
except SkippedException:
pass
@ -633,18 +662,13 @@ class Matcher(metaclass=MatcherMeta):
self.handlers.insert(0, handler) # type: ignore
updater = self.__class__._default_type_updater
if updater:
type_ = await updater(
bot,
event,
self.state, # type: ignore
self.type)
type_ = await updater(bot, event, self.state, self.type) # type: ignore
else:
type_ = "message"
updater = self.__class__._default_permission_updater
if updater:
permission = await updater(bot, event, self.state,
self.permission)
permission = await updater(bot, event, self.state, self.permission)
else:
permission = USER(event.get_session_id(), perm=self.permission)
@ -662,23 +686,18 @@ class Matcher(metaclass=MatcherMeta):
default_state=self.state,
default_parser=self.__class__._default_parser,
default_type_updater=self.__class__._default_type_updater,
default_permission_updater=self.__class__.
_default_permission_updater)
default_permission_updater=self.__class__._default_permission_updater,
)
except PausedException:
updater = self.__class__._default_type_updater
if updater:
type_ = await updater(
bot,
event,
self.state, # type: ignore
self.type)
type_ = await updater(bot, event, self.state, self.type) # type: ignore
else:
type_ = "message"
updater = self.__class__._default_permission_updater
if updater:
permission = await updater(bot, event, self.state,
self.permission)
permission = await updater(bot, event, self.state, self.permission)
else:
permission = USER(event.get_session_id(), perm=self.permission)
@ -696,8 +715,8 @@ class Matcher(metaclass=MatcherMeta):
default_state=self.state,
default_parser=self.__class__._default_parser,
default_type_updater=self.__class__._default_type_updater,
default_permission_updater=self.__class__.
_default_permission_updater)
default_permission_updater=self.__class__._default_permission_updater,
)
except FinishedException:
pass
except StopPropagation:

View File

@ -17,9 +17,14 @@ from nonebot.handler import Handler
from nonebot.utils import escape_tag
from nonebot.matcher import Matcher, matchers
from nonebot.exception import NoLogException, StopPropagation, IgnoredException
from nonebot.typing import (T_State, T_DependencyCache, T_RunPreProcessor,
T_RunPostProcessor, T_EventPreProcessor,
T_EventPostProcessor)
from nonebot.typing import (
T_State,
T_DependencyCache,
T_RunPreProcessor,
T_RunPostProcessor,
T_EventPreProcessor,
T_EventPostProcessor,
)
if TYPE_CHECKING:
from nonebot.adapters import Bot, Event
@ -30,15 +35,25 @@ _run_preprocessors: Set[Handler] = set()
_run_postprocessors: Set[Handler] = set()
EVENT_PCS_PARAMS = [
params.BotParam, params.EventParam, params.StateParam, params.DefaultParam
params.BotParam,
params.EventParam,
params.StateParam,
params.DefaultParam,
]
RUN_PREPCS_PARAMS = [
params.MatcherParam, params.BotParam, params.EventParam, params.StateParam,
params.DefaultParam
params.MatcherParam,
params.BotParam,
params.EventParam,
params.StateParam,
params.DefaultParam,
]
RUN_POSTPCS_PARAMS = [
params.MatcherParam, params.ExceptionParam, params.BotParam,
params.EventParam, params.StateParam, params.DefaultParam
params.MatcherParam,
params.ExceptionParam,
params.BotParam,
params.EventParam,
params.StateParam,
params.DefaultParam,
]
@ -83,13 +98,14 @@ def run_postprocessor(func: T_RunPostProcessor) -> T_RunPostProcessor:
async def _check_matcher(
priority: int,
Matcher: Type[Matcher],
bot: "Bot",
event: "Event",
state: T_State,
stack: Optional[AsyncExitStack] = None,
dependency_cache: Optional[T_DependencyCache] = None) -> None:
priority: int,
Matcher: Type[Matcher],
bot: "Bot",
event: "Event",
state: T_State,
stack: Optional[AsyncExitStack] = None,
dependency_cache: Optional[T_DependencyCache] = None,
) -> None:
if Matcher.expire_time and datetime.now() > Matcher.expire_time:
try:
matchers[priority].remove(Matcher)
@ -99,13 +115,13 @@ async def _check_matcher(
try:
if not await Matcher.check_perm(
bot, event, stack,
dependency_cache) or not await Matcher.check_rule(
bot, event, state, stack, dependency_cache):
bot, event, stack, dependency_cache
) or not await Matcher.check_rule(bot, event, state, stack, dependency_cache):
return
except Exception as e:
logger.opt(colors=True, exception=e).error(
f"<r><bg #f8bbd0>Rule check failed for {Matcher}.</bg #f8bbd0></r>")
f"<r><bg #f8bbd0>Rule check failed for {Matcher}.</bg #f8bbd0></r>"
)
return
if Matcher.temp:
@ -118,36 +134,43 @@ async def _check_matcher(
async def _run_matcher(
Matcher: Type[Matcher],
bot: "Bot",
event: "Event",
state: T_State,
stack: Optional[AsyncExitStack] = None,
dependency_cache: Optional[T_DependencyCache] = None) -> None:
Matcher: Type[Matcher],
bot: "Bot",
event: "Event",
state: T_State,
stack: Optional[AsyncExitStack] = None,
dependency_cache: Optional[T_DependencyCache] = None,
) -> None:
logger.info(f"Event will be handled by {Matcher}")
matcher = Matcher()
coros = list(
map(
lambda x: x(matcher=matcher,
bot=bot,
event=event,
state=state,
_stack=stack,
_dependency_cache=dependency_cache),
_run_preprocessors))
lambda x: x(
matcher=matcher,
bot=bot,
event=event,
state=state,
_stack=stack,
_dependency_cache=dependency_cache,
),
_run_preprocessors,
)
)
if coros:
try:
await asyncio.gather(*coros)
except IgnoredException:
logger.opt(colors=True).info(
f"Matcher {matcher} running is <b>cancelled</b>")
f"Matcher {matcher} running is <b>cancelled</b>"
)
return
except Exception as e:
logger.opt(colors=True, exception=e).error(
"<r><bg #f8bbd0>Error when running RunPreProcessors. "
"Running cancelled!</bg #f8bbd0></r>")
"Running cancelled!</bg #f8bbd0></r>"
)
return
exception = None
@ -163,14 +186,18 @@ async def _run_matcher(
coros = list(
map(
lambda x: x(matcher=matcher,
exception=exception,
bot=bot,
event=event,
state=state,
_stack=stack,
_dependency_cache=dependency_cache),
_run_postprocessors))
lambda x: x(
matcher=matcher,
exception=exception,
bot=bot,
event=event,
state=state,
_stack=stack,
_dependency_cache=dependency_cache,
),
_run_postprocessors,
)
)
if coros:
try:
await asyncio.gather(*coros)
@ -217,12 +244,16 @@ async def handle_event(bot: "Bot", event: "Event") -> None:
async with AsyncExitStack() as stack:
coros = list(
map(
lambda x: x(bot=bot,
event=event,
state=state,
_stack=stack,
_dependency_cache=dependency_cache),
_event_preprocessors))
lambda x: x(
bot=bot,
event=event,
state=state,
_stack=stack,
_dependency_cache=dependency_cache,
),
_event_preprocessors,
)
)
if coros:
try:
if show_log:
@ -236,7 +267,8 @@ async def handle_event(bot: "Bot", event: "Event") -> None:
except Exception as e:
logger.opt(colors=True, exception=e).error(
"<r><bg #f8bbd0>Error when running EventPreProcessors. "
"Event ignored!</bg #f8bbd0></r>")
"Event ignored!</bg #f8bbd0></r>"
)
return
# Trie Match
@ -251,13 +283,13 @@ async def handle_event(bot: "Bot", event: "Event") -> None:
logger.debug(f"Checking for matchers in priority {priority}...")
pending_tasks = [
_check_matcher(priority, matcher, bot, event, state.copy(),
stack, dependency_cache)
_check_matcher(
priority, matcher, bot, event, state.copy(), stack, dependency_cache
)
for matcher in matchers[priority]
]
results = await asyncio.gather(*pending_tasks,
return_exceptions=True)
results = await asyncio.gather(*pending_tasks, return_exceptions=True)
for result in results:
if not isinstance(result, Exception):
@ -272,12 +304,16 @@ async def handle_event(bot: "Bot", event: "Event") -> None:
coros = list(
map(
lambda x: x(bot=bot,
event=event,
state=state,
_stack=stack,
_dependency_cache=dependency_cache),
_event_postprocessors))
lambda x: x(
bot=bot,
event=event,
state=state,
_stack=stack,
_dependency_cache=dependency_cache,
),
_event_postprocessors,
)
)
if coros:
try:
if show_log:

View File

@ -10,69 +10,61 @@ from nonebot.utils import generic_check_issubclass
class BotParam(Param):
@classmethod
def _check(cls, name: str, param: inspect.Parameter) -> bool:
return generic_check_issubclass(
param.annotation, Bot) or (param.annotation == param.empty and
name == "bot")
return generic_check_issubclass(param.annotation, Bot) or (
param.annotation == param.empty and name == "bot"
)
def _solve(self, bot: Bot, **kwargs: Any) -> Any:
return bot
class EventParam(Param):
@classmethod
def _check(cls, name: str, param: inspect.Parameter) -> bool:
return generic_check_issubclass(
param.annotation, Event) or (param.annotation == param.empty and
name == "event")
return generic_check_issubclass(param.annotation, Event) or (
param.annotation == param.empty and name == "event"
)
def _solve(self, event: Event, **kwargs: Any) -> Any:
return event
class StateParam(Param):
@classmethod
def _check(cls, name: str, param: inspect.Parameter) -> bool:
return generic_check_issubclass(
param.annotation, Dict) or (param.annotation == param.empty and
name == "state")
return generic_check_issubclass(param.annotation, Dict) or (
param.annotation == param.empty and name == "state"
)
def _solve(self, state: T_State, **kwargs: Any) -> Any:
return state
class MatcherParam(Param):
@classmethod
def _check(cls, name: str, param: inspect.Parameter) -> bool:
return generic_check_issubclass(
param.annotation, Matcher) or (param.annotation == param.empty and
name == "matcher")
return generic_check_issubclass(param.annotation, Matcher) or (
param.annotation == param.empty and name == "matcher"
)
def _solve(self, matcher: Optional["Matcher"] = None, **kwargs: Any) -> Any:
return matcher
class ExceptionParam(Param):
@classmethod
def _check(cls, name: str, param: inspect.Parameter) -> bool:
return generic_check_issubclass(
param.annotation, Exception) or (param.annotation == param.empty and
name == "exception")
return generic_check_issubclass(param.annotation, Exception) or (
param.annotation == param.empty and name == "exception"
)
def _solve(self,
exception: Optional[Exception] = None,
**kwargs: Any) -> Any:
def _solve(self, exception: Optional[Exception] = None, **kwargs: Any) -> Any:
return exception
class DefaultParam(Param):
@classmethod
def _check(cls, name: str, param: inspect.Parameter) -> bool:
return param.default != param.empty

View File

@ -34,11 +34,10 @@ class Permission:
from nonebot.utils import run_sync
Permission(async_function, run_sync(sync_function))
"""
__slots__ = ("checkers",)
HANDLER_PARAM_TYPES = [
params.BotParam, params.EventParam, params.DefaultParam
]
HANDLER_PARAM_TYPES = [params.BotParam, params.EventParam, params.DefaultParam]
def __init__(self, *checkers: Union[T_PermissionChecker, Handler]) -> None:
"""
@ -48,9 +47,11 @@ class Permission:
"""
self.checkers = set(
checker if isinstance(checker, Handler) else Handler(
checker, allow_types=self.HANDLER_PARAM_TYPES)
for checker in checkers)
checker
if isinstance(checker, Handler)
else Handler(checker, allow_types=self.HANDLER_PARAM_TYPES)
for checker in checkers
)
"""
:说明:
@ -66,8 +67,8 @@ class Permission:
bot: Bot,
event: Event,
stack: Optional[AsyncExitStack] = None,
dependency_cache: Optional[Dict[Callable[..., Any],
Any]] = None) -> bool:
dependency_cache: Optional[Dict[Callable[..., Any], Any]] = None,
) -> bool:
"""
:说明:
@ -87,19 +88,24 @@ class Permission:
if not self.checkers:
return True
results = await asyncio.gather(
*(checker(bot=bot,
event=event,
_stack=stack,
_dependency_cache=dependency_cache)
for checker in self.checkers))
*(
checker(
bot=bot,
event=event,
_stack=stack,
_dependency_cache=dependency_cache,
)
for checker in self.checkers
)
)
return any(results)
def __and__(self, other) -> NoReturn:
raise RuntimeError("And operation between Permissions is not allowed.")
def __or__(
self, other: Optional[Union["Permission",
T_PermissionChecker]]) -> "Permission":
self, other: Optional[Union["Permission", T_PermissionChecker]]
) -> "Permission":
if other is None:
return self
elif isinstance(other, Permission):
@ -155,15 +161,17 @@ def USER(*user: str, perm: Optional[Permission] = None):
"""
async def _user(bot: Bot, event: Event) -> bool:
return bool(event.get_session_id() in user and
(perm is None or await perm(bot, event)))
return bool(
event.get_session_id() in user and (perm is None or await perm(bot, event))
)
return Permission(_user)
async def _superuser(bot: Bot, event: Event) -> bool:
return (event.get_type() == "message" and
event.get_user_id() in bot.config.superusers)
return (
event.get_type() == "message" and event.get_user_id() in bot.config.superusers
)
SUPERUSER = Permission(_superuser)

View File

@ -9,8 +9,9 @@ from typing import List, Optional
from contextvars import ContextVar
_managers: List["PluginManager"] = []
_current_plugin: ContextVar[Optional["Plugin"]] = ContextVar("_current_plugin",
default=None)
_current_plugin: ContextVar[Optional["Plugin"]] = ContextVar(
"_current_plugin", default=None
)
from .on import on as on
from .manager import PluginManager

View File

@ -33,8 +33,7 @@ class Export(dict):
return func
def __setitem__(self, key, value):
super().__setitem__(key,
Export(value) if isinstance(value, dict) else value)
super().__setitem__(key, Export(value) if isinstance(value, dict) else value)
def __setattr__(self, name, value):
self[name] = Export(value) if isinstance(value, dict) else value

View File

@ -49,8 +49,9 @@ def load_plugins(*plugin_dir: str) -> Set[Plugin]:
return manager.load_all_plugins()
def load_all_plugins(module_path: Iterable[str],
plugin_dir: Iterable[str]) -> Set[Plugin]:
def load_all_plugins(
module_path: Iterable[str], plugin_dir: Iterable[str]
) -> Set[Plugin]:
"""
:说明:
@ -90,8 +91,7 @@ def load_from_json(file_path: str, encoding: str = "utf-8") -> Set[Plugin]:
plugins = data.get("plugins")
plugin_dirs = data.get("plugin_dirs")
assert isinstance(plugins, list), "plugins must be a list of plugin name"
assert isinstance(plugin_dirs,
list), "plugin_dirs must be a list of directories"
assert isinstance(plugin_dirs, list), "plugin_dirs must be a list of directories"
return load_all_plugins(set(plugins), set(plugin_dirs))
@ -120,14 +120,14 @@ def load_from_toml(file_path: str, encoding: str = "utf-8") -> Set[Plugin]:
if nonebot_data:
warnings.warn(
"[nonebot.plugins] table are now deprecated. Use [tool.nonebot] instead.",
DeprecationWarning)
DeprecationWarning,
)
else:
raise ValueError("Cannot find '[tool.nonebot]' in given toml file!")
plugins = nonebot_data.get("plugins", [])
plugin_dirs = nonebot_data.get("plugin_dirs", [])
assert isinstance(plugins, list), "plugins must be a list of plugin name"
assert isinstance(plugin_dirs,
list), "plugin_dirs must be a list of directories"
assert isinstance(plugin_dirs, list), "plugin_dirs must be a list of directories"
return load_all_plugins(plugins, plugin_dirs)
@ -163,5 +163,5 @@ def require(name: str) -> Export:
"""
plugin = get_plugin(name) or load_plugin(name)
if not plugin:
raise RuntimeError(f"Cannot load plugin \"{name}\"!")
raise RuntimeError(f'Cannot load plugin "{name}"!')
return plugin.export

View File

@ -15,7 +15,6 @@ from . import _managers, _current_plugin
class PluginManager:
def __init__(
self,
plugins: Optional[Iterable[str]] = None,
@ -39,14 +38,15 @@ class PluginManager:
def _previous_plugins(self) -> List[str]:
_pre_managers: List[PluginManager]
if self in _managers:
_pre_managers = _managers[:_managers.index(self)]
_pre_managers = _managers[: _managers.index(self)]
else:
_pre_managers = _managers[:]
return [
*chain.from_iterable(
[*manager.plugins, *manager.searched_plugins.keys()]
for manager in _pre_managers)
for manager in _pre_managers
)
]
def list_plugins(self) -> Set[str]:
@ -57,13 +57,14 @@ class PluginManager:
for module_info in pkgutil.iter_modules(self.search_path):
if module_info.name.startswith("_"):
continue
if module_info.name in searched_plugins.keys(
) or module_info.name in previous_plugins:
if (
module_info.name in searched_plugins.keys()
or module_info.name in previous_plugins
):
raise RuntimeError(
f"Plugin already exists: {module_info.name}! Check your plugin name"
)
module_spec = module_info.module_finder.find_spec(
module_info.name, None)
module_spec = module_info.module_finder.find_spec(module_info.name, None)
if not module_spec:
continue
module_path = module_spec.origin
@ -80,14 +81,15 @@ class PluginManager:
if name in self.plugins:
module = importlib.import_module(name)
elif name not in self.searched_plugins:
raise RuntimeError(
f"Plugin not found: {name}! Check your plugin name")
raise RuntimeError(f"Plugin not found: {name}! Check your plugin name")
else:
module = importlib.import_module(
self._path_to_module_name(self.searched_plugins[name]))
self._path_to_module_name(self.searched_plugins[name])
)
logger.opt(colors=True).success(
f'Succeeded to import "<y>{escape_tag(name)}</y>"')
f'Succeeded to import "<y>{escape_tag(name)}</y>"'
)
return getattr(module, "__plugin__", None)
except Exception as e:
logger.opt(colors=True, exception=e).error(
@ -96,16 +98,17 @@ class PluginManager:
def load_all_plugins(self) -> Set[Plugin]:
return set(
filter(None,
(self.load_plugin(name) for name in self.list_plugins())))
filter(None, (self.load_plugin(name) for name in self.list_plugins()))
)
class PluginFinder(MetaPathFinder):
def find_spec(self,
fullname: str,
path: Optional[Sequence[Union[bytes, str]]],
target: Optional[ModuleType] = None):
def find_spec(
self,
fullname: str,
path: Optional[Sequence[Union[bytes, str]]],
target: Optional[ModuleType] = None,
):
if _managers:
index = -1
module_spec = PathFinder.find_spec(fullname, path, target)
@ -119,10 +122,11 @@ class PluginFinder(MetaPathFinder):
while -index <= len(_managers):
manager = _managers[index]
if fullname in manager.plugins or module_path in manager.searched_plugins.values(
if (
fullname in manager.plugins
or module_path in manager.searched_plugins.values()
):
module_spec.loader = PluginLoader(manager, fullname,
module_origin)
module_spec.loader = PluginLoader(manager, fullname, module_origin)
return module_spec
index -= 1
@ -130,7 +134,6 @@ class PluginFinder(MetaPathFinder):
class PluginLoader(SourceFileLoader):
def __init__(self, manager: PluginManager, fullname: str, path) -> None:
self.manager = manager
self.loaded = False

View File

@ -10,8 +10,18 @@ from nonebot.matcher import Matcher
from .manager import _current_plugin
from nonebot.permission import Permission
from nonebot.typing import T_State, T_Handler, T_RuleChecker, T_StateFactory
from nonebot.rule import (PREFIX_KEY, RAW_CMD_KEY, Rule, ArgumentParser, regex,
command, keyword, endswith, startswith, shell_command)
from nonebot.rule import (
PREFIX_KEY,
RAW_CMD_KEY,
Rule,
ArgumentParser,
regex,
command,
keyword,
endswith,
startswith,
shell_command,
)
def _store_matcher(matcher: Type[Matcher]) -> None:
@ -30,17 +40,19 @@ def _get_matcher_module(depth: int = 1) -> Optional[ModuleType]:
return sys.modules.get(module_name)
def on(type: str = "",
rule: Optional[Union[Rule, T_RuleChecker]] = None,
permission: Optional[Permission] = None,
*,
handlers: Optional[List[Union[T_Handler, Handler]]] = None,
temp: bool = False,
priority: int = 1,
block: bool = False,
state: Optional[T_State] = None,
state_factory: Optional[T_StateFactory] = None,
_depth: int = 0) -> Type[Matcher]:
def on(
type: str = "",
rule: Optional[Union[Rule, T_RuleChecker]] = None,
permission: Optional[Permission] = None,
*,
handlers: Optional[List[Union[T_Handler, Handler]]] = None,
temp: bool = False,
priority: int = 1,
block: bool = False,
state: Optional[T_State] = None,
state_factory: Optional[T_StateFactory] = None,
_depth: int = 0,
) -> Type[Matcher]:
"""
:说明:
@ -62,30 +74,34 @@ def on(type: str = "",
- ``Type[Matcher]``
"""
matcher = Matcher.new(type,
Rule() & rule,
permission or Permission(),
temp=temp,
priority=priority,
block=block,
handlers=handlers,
plugin=_current_plugin.get(),
module=_get_matcher_module(_depth + 1),
default_state=state,
default_state_factory=state_factory)
matcher = Matcher.new(
type,
Rule() & rule,
permission or Permission(),
temp=temp,
priority=priority,
block=block,
handlers=handlers,
plugin=_current_plugin.get(),
module=_get_matcher_module(_depth + 1),
default_state=state,
default_state_factory=state_factory,
)
_store_matcher(matcher)
return matcher
def on_metaevent(rule: Optional[Union[Rule, T_RuleChecker]] = None,
*,
handlers: Optional[List[Union[T_Handler, Handler]]] = None,
temp: bool = False,
priority: int = 1,
block: bool = False,
state: Optional[T_State] = None,
state_factory: Optional[T_StateFactory] = None,
_depth: int = 0) -> Type[Matcher]:
def on_metaevent(
rule: Optional[Union[Rule, T_RuleChecker]] = None,
*,
handlers: Optional[List[Union[T_Handler, Handler]]] = None,
temp: bool = False,
priority: int = 1,
block: bool = False,
state: Optional[T_State] = None,
state_factory: Optional[T_StateFactory] = None,
_depth: int = 0,
) -> Type[Matcher]:
"""
:说明:
@ -105,31 +121,35 @@ def on_metaevent(rule: Optional[Union[Rule, T_RuleChecker]] = None,
- ``Type[Matcher]``
"""
matcher = Matcher.new("meta_event",
Rule() & rule,
Permission(),
temp=temp,
priority=priority,
block=block,
handlers=handlers,
plugin=_current_plugin.get(),
module=_get_matcher_module(_depth + 1),
default_state=state,
default_state_factory=state_factory)
matcher = Matcher.new(
"meta_event",
Rule() & rule,
Permission(),
temp=temp,
priority=priority,
block=block,
handlers=handlers,
plugin=_current_plugin.get(),
module=_get_matcher_module(_depth + 1),
default_state=state,
default_state_factory=state_factory,
)
_store_matcher(matcher)
return matcher
def on_message(rule: Optional[Union[Rule, T_RuleChecker]] = None,
permission: Optional[Permission] = None,
*,
handlers: Optional[List[Union[T_Handler, Handler]]] = None,
temp: bool = False,
priority: int = 1,
block: bool = True,
state: Optional[T_State] = None,
state_factory: Optional[T_StateFactory] = None,
_depth: int = 0) -> Type[Matcher]:
def on_message(
rule: Optional[Union[Rule, T_RuleChecker]] = None,
permission: Optional[Permission] = None,
*,
handlers: Optional[List[Union[T_Handler, Handler]]] = None,
temp: bool = False,
priority: int = 1,
block: bool = True,
state: Optional[T_State] = None,
state_factory: Optional[T_StateFactory] = None,
_depth: int = 0,
) -> Type[Matcher]:
"""
:说明:
@ -150,30 +170,34 @@ def on_message(rule: Optional[Union[Rule, T_RuleChecker]] = None,
- ``Type[Matcher]``
"""
matcher = Matcher.new("message",
Rule() & rule,
permission or Permission(),
temp=temp,
priority=priority,
block=block,
handlers=handlers,
plugin=_current_plugin.get(),
module=_get_matcher_module(_depth + 1),
default_state=state,
default_state_factory=state_factory)
matcher = Matcher.new(
"message",
Rule() & rule,
permission or Permission(),
temp=temp,
priority=priority,
block=block,
handlers=handlers,
plugin=_current_plugin.get(),
module=_get_matcher_module(_depth + 1),
default_state=state,
default_state_factory=state_factory,
)
_store_matcher(matcher)
return matcher
def on_notice(rule: Optional[Union[Rule, T_RuleChecker]] = None,
*,
handlers: Optional[List[Union[T_Handler, Handler]]] = None,
temp: bool = False,
priority: int = 1,
block: bool = False,
state: Optional[T_State] = None,
state_factory: Optional[T_StateFactory] = None,
_depth: int = 0) -> Type[Matcher]:
def on_notice(
rule: Optional[Union[Rule, T_RuleChecker]] = None,
*,
handlers: Optional[List[Union[T_Handler, Handler]]] = None,
temp: bool = False,
priority: int = 1,
block: bool = False,
state: Optional[T_State] = None,
state_factory: Optional[T_StateFactory] = None,
_depth: int = 0,
) -> Type[Matcher]:
"""
:说明:
@ -193,30 +217,34 @@ def on_notice(rule: Optional[Union[Rule, T_RuleChecker]] = None,
- ``Type[Matcher]``
"""
matcher = Matcher.new("notice",
Rule() & rule,
Permission(),
temp=temp,
priority=priority,
block=block,
handlers=handlers,
plugin=_current_plugin.get(),
module=_get_matcher_module(_depth + 1),
default_state=state,
default_state_factory=state_factory)
matcher = Matcher.new(
"notice",
Rule() & rule,
Permission(),
temp=temp,
priority=priority,
block=block,
handlers=handlers,
plugin=_current_plugin.get(),
module=_get_matcher_module(_depth + 1),
default_state=state,
default_state_factory=state_factory,
)
_store_matcher(matcher)
return matcher
def on_request(rule: Optional[Union[Rule, T_RuleChecker]] = None,
*,
handlers: Optional[List[Union[T_Handler, Handler]]] = None,
temp: bool = False,
priority: int = 1,
block: bool = False,
state: Optional[T_State] = None,
state_factory: Optional[T_StateFactory] = None,
_depth: int = 0) -> Type[Matcher]:
def on_request(
rule: Optional[Union[Rule, T_RuleChecker]] = None,
*,
handlers: Optional[List[Union[T_Handler, Handler]]] = None,
temp: bool = False,
priority: int = 1,
block: bool = False,
state: Optional[T_State] = None,
state_factory: Optional[T_StateFactory] = None,
_depth: int = 0,
) -> Type[Matcher]:
"""
:说明:
@ -236,26 +264,30 @@ def on_request(rule: Optional[Union[Rule, T_RuleChecker]] = None,
- ``Type[Matcher]``
"""
matcher = Matcher.new("request",
Rule() & rule,
Permission(),
temp=temp,
priority=priority,
block=block,
handlers=handlers,
plugin=_current_plugin.get(),
module=_get_matcher_module(_depth + 1),
default_state=state,
default_state_factory=state_factory)
matcher = Matcher.new(
"request",
Rule() & rule,
Permission(),
temp=temp,
priority=priority,
block=block,
handlers=handlers,
plugin=_current_plugin.get(),
module=_get_matcher_module(_depth + 1),
default_state=state,
default_state_factory=state_factory,
)
_store_matcher(matcher)
return matcher
def on_startswith(msg: Union[str, Tuple[str, ...]],
rule: Optional[Optional[Union[Rule, T_RuleChecker]]] = None,
ignorecase: bool = False,
_depth: int = 0,
**kwargs) -> Type[Matcher]:
def on_startswith(
msg: Union[str, Tuple[str, ...]],
rule: Optional[Optional[Union[Rule, T_RuleChecker]]] = None,
ignorecase: bool = False,
_depth: int = 0,
**kwargs,
) -> Type[Matcher]:
"""
:说明:
@ -278,16 +310,16 @@ def on_startswith(msg: Union[str, Tuple[str, ...]],
- ``Type[Matcher]``
"""
return on_message(startswith(msg, ignorecase) & rule,
**kwargs,
_depth=_depth + 1)
return on_message(startswith(msg, ignorecase) & rule, **kwargs, _depth=_depth + 1)
def on_endswith(msg: Union[str, Tuple[str, ...]],
rule: Optional[Optional[Union[Rule, T_RuleChecker]]] = None,
ignorecase: bool = False,
_depth: int = 0,
**kwargs) -> Type[Matcher]:
def on_endswith(
msg: Union[str, Tuple[str, ...]],
rule: Optional[Optional[Union[Rule, T_RuleChecker]]] = None,
ignorecase: bool = False,
_depth: int = 0,
**kwargs,
) -> Type[Matcher]:
"""
:说明:
@ -310,15 +342,15 @@ def on_endswith(msg: Union[str, Tuple[str, ...]],
- ``Type[Matcher]``
"""
return on_message(endswith(msg, ignorecase) & rule,
**kwargs,
_depth=_depth + 1)
return on_message(endswith(msg, ignorecase) & rule, **kwargs, _depth=_depth + 1)
def on_keyword(keywords: Set[str],
rule: Optional[Union[Rule, T_RuleChecker]] = None,
_depth: int = 0,
**kwargs) -> Type[Matcher]:
def on_keyword(
keywords: Set[str],
rule: Optional[Union[Rule, T_RuleChecker]] = None,
_depth: int = 0,
**kwargs,
) -> Type[Matcher]:
"""
:说明:
@ -343,11 +375,13 @@ def on_keyword(keywords: Set[str],
return on_message(keyword(*keywords) & rule, **kwargs, _depth=_depth + 1)
def on_command(cmd: Union[str, Tuple[str, ...]],
rule: Optional[Union[Rule, T_RuleChecker]] = None,
aliases: Optional[Set[Union[str, Tuple[str, ...]]]] = None,
_depth: int = 0,
**kwargs) -> Type[Matcher]:
def on_command(
cmd: Union[str, Tuple[str, ...]],
rule: Optional[Union[Rule, T_RuleChecker]] = None,
aliases: Optional[Set[Union[str, Tuple[str, ...]]]] = None,
_depth: int = 0,
**kwargs,
) -> Type[Matcher]:
"""
:说明:
@ -382,7 +416,8 @@ def on_command(cmd: Union[str, Tuple[str, ...]],
if not segment_text.startswith(state[PREFIX_KEY][RAW_CMD_KEY]):
return
new_message = message.__class__(
segment_text[len(state[PREFIX_KEY][RAW_CMD_KEY]):].lstrip())
segment_text[len(state[PREFIX_KEY][RAW_CMD_KEY]) :].lstrip()
)
for new_segment in reversed(new_message):
message.insert(0, new_segment)
@ -390,18 +425,19 @@ def on_command(cmd: Union[str, Tuple[str, ...]],
handlers.insert(0, _strip_cmd)
commands = set([cmd]) | (aliases or set())
return on_message(command(*commands) & rule,
handlers=handlers,
**kwargs,
_depth=_depth + 1)
return on_message(
command(*commands) & rule, handlers=handlers, **kwargs, _depth=_depth + 1
)
def on_shell_command(cmd: Union[str, Tuple[str, ...]],
rule: Optional[Union[Rule, T_RuleChecker]] = None,
aliases: Optional[Set[Union[str, Tuple[str, ...]]]] = None,
parser: Optional[ArgumentParser] = None,
_depth: int = 0,
**kwargs) -> Type[Matcher]:
def on_shell_command(
cmd: Union[str, Tuple[str, ...]],
rule: Optional[Union[Rule, T_RuleChecker]] = None,
aliases: Optional[Set[Union[str, Tuple[str, ...]]]] = None,
parser: Optional[ArgumentParser] = None,
_depth: int = 0,
**kwargs,
) -> Type[Matcher]:
"""
:说明:
@ -434,7 +470,8 @@ def on_shell_command(cmd: Union[str, Tuple[str, ...]],
message = event.get_message()
segment = message.pop(0)
new_message = message.__class__(
str(segment)[len(state[PREFIX_KEY][RAW_CMD_KEY]):].strip())
str(segment)[len(state[PREFIX_KEY][RAW_CMD_KEY]) :].strip()
)
for new_segment in reversed(new_message):
message.insert(0, new_segment)
@ -442,17 +479,21 @@ def on_shell_command(cmd: Union[str, Tuple[str, ...]],
handlers.insert(0, _strip_cmd)
commands = set([cmd]) | (aliases or set())
return on_message(shell_command(*commands, parser=parser) & rule,
handlers=handlers,
**kwargs,
_depth=_depth + 1)
return on_message(
shell_command(*commands, parser=parser) & rule,
handlers=handlers,
**kwargs,
_depth=_depth + 1,
)
def on_regex(pattern: str,
flags: Union[int, re.RegexFlag] = 0,
rule: Optional[Union[Rule, T_RuleChecker]] = None,
_depth: int = 0,
**kwargs) -> Type[Matcher]:
def on_regex(
pattern: str,
flags: Union[int, re.RegexFlag] = 0,
rule: Optional[Union[Rule, T_RuleChecker]] = None,
_depth: int = 0,
**kwargs,
) -> Type[Matcher]:
"""
:说明:
@ -503,8 +544,7 @@ class CommandGroup:
- **说明**: 其他传递给 ``on_command`` 的参数默认值
"""
def command(self, cmd: Union[str, Tuple[str, ...]],
**kwargs) -> Type[Matcher]:
def command(self, cmd: Union[str, Tuple[str, ...]], **kwargs) -> Type[Matcher]:
"""
:说明:
@ -526,8 +566,9 @@ class CommandGroup:
final_kwargs.update(kwargs)
return on_command(cmd, **final_kwargs, _depth=1)
def shell_command(self, cmd: Union[str, Tuple[str, ...]],
**kwargs) -> Type[Matcher]:
def shell_command(
self, cmd: Union[str, Tuple[str, ...]], **kwargs
) -> Type[Matcher]:
"""
:说明:
@ -708,8 +749,9 @@ class MatcherGroup:
self.matchers.append(matcher)
return matcher
def on_startswith(self, msg: Union[str, Tuple[str, ...]],
**kwargs) -> Type[Matcher]:
def on_startswith(
self, msg: Union[str, Tuple[str, ...]], **kwargs
) -> Type[Matcher]:
"""
:说明:
@ -739,8 +781,7 @@ class MatcherGroup:
self.matchers.append(matcher)
return matcher
def on_endswith(self, msg: Union[str, Tuple[str, ...]],
**kwargs) -> Type[Matcher]:
def on_endswith(self, msg: Union[str, Tuple[str, ...]], **kwargs) -> Type[Matcher]:
"""
:说明:
@ -799,10 +840,12 @@ class MatcherGroup:
self.matchers.append(matcher)
return matcher
def on_command(self,
cmd: Union[str, Tuple[str, ...]],
aliases: Optional[Set[Union[str, Tuple[str, ...]]]] = None,
**kwargs) -> Type[Matcher]:
def on_command(
self,
cmd: Union[str, Tuple[str, ...]],
aliases: Optional[Set[Union[str, Tuple[str, ...]]]] = None,
**kwargs,
) -> Type[Matcher]:
"""
:说明:
@ -834,12 +877,13 @@ class MatcherGroup:
self.matchers.append(matcher)
return matcher
def on_shell_command(self,
cmd: Union[str, Tuple[str, ...]],
aliases: Optional[Set[Union[str, Tuple[str,
...]]]] = None,
parser: Optional[ArgumentParser] = None,
**kwargs) -> Type[Matcher]:
def on_shell_command(
self,
cmd: Union[str, Tuple[str, ...]],
aliases: Optional[Set[Union[str, Tuple[str, ...]]]] = None,
parser: Optional[ArgumentParser] = None,
**kwargs,
) -> Type[Matcher]:
"""
:说明:
@ -870,18 +914,15 @@ class MatcherGroup:
final_kwargs = self.base_kwargs.copy()
final_kwargs.update(kwargs)
final_kwargs.pop("type", None)
matcher = on_shell_command(cmd,
aliases=aliases,
parser=parser,
**final_kwargs,
_depth=1)
matcher = on_shell_command(
cmd, aliases=aliases, parser=parser, **final_kwargs, _depth=1
)
self.matchers.append(matcher)
return matcher
def on_regex(self,
pattern: str,
flags: Union[int, re.RegexFlag] = 0,
**kwargs) -> Type[Matcher]:
def on_regex(
self, pattern: str, flags: Union[int, re.RegexFlag] = 0, **kwargs
) -> Type[Matcher]:
"""
:说明:

View File

@ -7,361 +7,335 @@ from nonebot.permission import Permission
from nonebot.rule import Rule, ArgumentParser
from nonebot.typing import T_State, T_Handler, T_RuleChecker, T_StateFactory
def on(type: str = "",
rule: Optional[Union[Rule, T_RuleChecker]] = ...,
permission: Optional[Permission] = ...,
*,
handlers: Optional[List[Union[T_Handler, Handler]]] = ...,
temp: bool = ...,
priority: int = ...,
block: bool = ...,
state: Optional[T_State] = ...,
state_factory: Optional[T_StateFactory] = ...) -> Type[Matcher]:
...
def on(
type: str = "",
rule: Optional[Union[Rule, T_RuleChecker]] = ...,
permission: Optional[Permission] = ...,
*,
handlers: Optional[List[Union[T_Handler, Handler]]] = ...,
temp: bool = ...,
priority: int = ...,
block: bool = ...,
state: Optional[T_State] = ...,
state_factory: Optional[T_StateFactory] = ...,
) -> Type[Matcher]: ...
def on_metaevent(
rule: Optional[Union[Rule, T_RuleChecker]] = ...,
*,
handlers: Optional[List[Union[T_Handler, Handler]]] = ...,
temp: bool = ...,
priority: int = ...,
block: bool = ...,
state: Optional[T_State] = ...,
state_factory: Optional[T_StateFactory] = ...) -> Type[Matcher]:
...
def on_message(rule: Optional[Union[Rule, T_RuleChecker]] = ...,
permission: Optional[Permission] = ...,
*,
handlers: Optional[List[Union[T_Handler, Handler]]] = ...,
temp: bool = ...,
priority: int = ...,
block: bool = ...,
state: Optional[T_State] = ...,
state_factory: Optional[T_StateFactory] = ...) -> Type[Matcher]:
...
def on_notice(rule: Optional[Union[Rule, T_RuleChecker]] = ...,
*,
handlers: Optional[List[Union[T_Handler, Handler]]] = ...,
temp: bool = ...,
priority: int = ...,
block: bool = ...,
state: Optional[T_State] = ...,
state_factory: Optional[T_StateFactory] = ...) -> Type[Matcher]:
...
def on_request(rule: Optional[Union[Rule, T_RuleChecker]] = ...,
*,
handlers: Optional[List[Union[T_Handler, Handler]]] = ...,
temp: bool = ...,
priority: int = ...,
block: bool = ...,
state: Optional[T_State] = ...,
state_factory: Optional[T_StateFactory] = ...) -> Type[Matcher]:
...
rule: Optional[Union[Rule, T_RuleChecker]] = ...,
*,
handlers: Optional[List[Union[T_Handler, Handler]]] = ...,
temp: bool = ...,
priority: int = ...,
block: bool = ...,
state: Optional[T_State] = ...,
state_factory: Optional[T_StateFactory] = ...,
) -> Type[Matcher]: ...
def on_message(
rule: Optional[Union[Rule, T_RuleChecker]] = ...,
permission: Optional[Permission] = ...,
*,
handlers: Optional[List[Union[T_Handler, Handler]]] = ...,
temp: bool = ...,
priority: int = ...,
block: bool = ...,
state: Optional[T_State] = ...,
state_factory: Optional[T_StateFactory] = ...,
) -> Type[Matcher]: ...
def on_notice(
rule: Optional[Union[Rule, T_RuleChecker]] = ...,
*,
handlers: Optional[List[Union[T_Handler, Handler]]] = ...,
temp: bool = ...,
priority: int = ...,
block: bool = ...,
state: Optional[T_State] = ...,
state_factory: Optional[T_StateFactory] = ...,
) -> Type[Matcher]: ...
def on_request(
rule: Optional[Union[Rule, T_RuleChecker]] = ...,
*,
handlers: Optional[List[Union[T_Handler, Handler]]] = ...,
temp: bool = ...,
priority: int = ...,
block: bool = ...,
state: Optional[T_State] = ...,
state_factory: Optional[T_StateFactory] = ...,
) -> Type[Matcher]: ...
def on_startswith(
msg: Union[str, Tuple[str, ...]],
rule: Optional[Optional[Union[Rule, T_RuleChecker]]] = ...,
ignorecase: bool = ...,
msg: Union[str, Tuple[str, ...]],
rule: Optional[Optional[Union[Rule, T_RuleChecker]]] = ...,
ignorecase: bool = ...,
*,
permission: Optional[Permission] = ...,
handlers: Optional[List[Union[T_Handler, Handler]]] = ...,
temp: bool = ...,
priority: int = ...,
block: bool = ...,
state: Optional[T_State] = ...,
state_factory: Optional[T_StateFactory] = ...,
) -> Type[Matcher]: ...
def on_endswith(
msg: Union[str, Tuple[str, ...]],
rule: Optional[Optional[Union[Rule, T_RuleChecker]]] = ...,
ignorecase: bool = ...,
*,
permission: Optional[Permission] = ...,
handlers: Optional[List[Union[T_Handler, Handler]]] = ...,
temp: bool = ...,
priority: int = ...,
block: bool = ...,
state: Optional[T_State] = ...,
state_factory: Optional[T_StateFactory] = ...,
) -> Type[Matcher]: ...
def on_keyword(
keywords: Set[str],
rule: Optional[Union[Rule, T_RuleChecker]] = ...,
*,
permission: Optional[Permission] = ...,
handlers: Optional[List[Union[T_Handler, Handler]]] = ...,
temp: bool = ...,
priority: int = ...,
block: bool = ...,
state: Optional[T_State] = ...,
state_factory: Optional[T_StateFactory] = ...,
) -> Type[Matcher]: ...
def on_command(
cmd: Union[str, Tuple[str, ...]],
rule: Optional[Union[Rule, T_RuleChecker]] = ...,
aliases: Optional[Set[Union[str, Tuple[str, ...]]]] = ...,
*,
permission: Optional[Permission] = ...,
handlers: Optional[List[Union[T_Handler, Handler]]] = ...,
temp: bool = ...,
priority: int = ...,
block: bool = ...,
state: Optional[T_State] = ...,
state_factory: Optional[T_StateFactory] = ...,
) -> Type[Matcher]: ...
def on_shell_command(
cmd: Union[str, Tuple[str, ...]],
rule: Optional[Union[Rule, T_RuleChecker]] = ...,
aliases: Optional[Set[Union[str, Tuple[str, ...]]]] = ...,
parser: Optional[ArgumentParser] = ...,
*,
permission: Optional[Permission] = ...,
handlers: Optional[List[Union[T_Handler, Handler]]] = ...,
temp: bool = ...,
priority: int = ...,
block: bool = ...,
state: Optional[T_State] = ...,
state_factory: Optional[T_StateFactory] = ...,
) -> Type[Matcher]: ...
def on_regex(
pattern: str,
flags: Union[int, re.RegexFlag] = ...,
rule: Optional[Union[Rule, T_RuleChecker]] = ...,
*,
permission: Optional[Permission] = ...,
handlers: Optional[List[Union[T_Handler, Handler]]] = ...,
temp: bool = ...,
priority: int = ...,
block: bool = ...,
state: Optional[T_State] = ...,
state_factory: Optional[T_StateFactory] = ...,
) -> Type[Matcher]: ...
class CommandGroup:
def __init__(
self,
cmd: Union[str, Tuple[str, ...]],
*,
rule: Optional[Union[Rule, T_RuleChecker]] = ...,
permission: Optional[Permission] = ...,
handlers: Optional[List[Union[T_Handler, Handler]]] = ...,
temp: bool = ...,
priority: int = ...,
block: bool = ...,
state: Optional[T_State] = ...,
state_factory: Optional[T_StateFactory] = ...) -> Type[Matcher]:
...
def on_endswith(msg: Union[str, Tuple[str, ...]],
rule: Optional[Optional[Union[Rule, T_RuleChecker]]] = ...,
ignorecase: bool = ...,
*,
permission: Optional[Permission] = ...,
handlers: Optional[List[Union[T_Handler, Handler]]] = ...,
temp: bool = ...,
priority: int = ...,
block: bool = ...,
state: Optional[T_State] = ...,
state_factory: Optional[T_StateFactory] = ...) -> Type[Matcher]:
...
def on_keyword(keywords: Set[str],
rule: Optional[Union[Rule, T_RuleChecker]] = ...,
*,
permission: Optional[Permission] = ...,
handlers: Optional[List[Union[T_Handler, Handler]]] = ...,
temp: bool = ...,
priority: int = ...,
block: bool = ...,
state: Optional[T_State] = ...,
state_factory: Optional[T_StateFactory] = ...) -> Type[Matcher]:
...
def on_command(cmd: Union[str, Tuple[str, ...]],
rule: Optional[Union[Rule, T_RuleChecker]] = ...,
aliases: Optional[Set[Union[str, Tuple[str, ...]]]] = ...,
*,
permission: Optional[Permission] = ...,
handlers: Optional[List[Union[T_Handler, Handler]]] = ...,
temp: bool = ...,
priority: int = ...,
block: bool = ...,
state: Optional[T_State] = ...,
state_factory: Optional[T_StateFactory] = ...) -> Type[Matcher]:
...
def on_shell_command(
state_factory: Optional[T_StateFactory] = ...,
): ...
def command(
self,
cmd: Union[str, Tuple[str, ...]],
*,
aliases: Optional[Set[Union[str, Tuple[str, ...]]]],
rule: Optional[Union[Rule, T_RuleChecker]] = ...,
permission: Optional[Permission] = ...,
handlers: Optional[List[Union[T_Handler, Handler]]] = ...,
temp: bool = ...,
priority: int = ...,
block: bool = ...,
state: Optional[T_State] = ...,
state_factory: Optional[T_StateFactory] = ...,
) -> Type[Matcher]: ...
def shell_command(
self,
cmd: Union[str, Tuple[str, ...]],
*,
rule: Optional[Union[Rule, T_RuleChecker]] = ...,
aliases: Optional[Set[Union[str, Tuple[str, ...]]]],
parser: Optional[ArgumentParser] = ...,
permission: Optional[Permission] = ...,
handlers: Optional[List[Union[T_Handler, Handler]]] = ...,
temp: bool = ...,
priority: int = ...,
block: bool = ...,
state: Optional[T_State] = ...,
state_factory: Optional[T_StateFactory] = ...,
) -> Type[Matcher]: ...
class MatcherGroup:
def __init__(
self,
*,
type: str = ...,
rule: Optional[Union[Rule, T_RuleChecker]] = ...,
permission: Optional[Permission] = ...,
handlers: Optional[List[Union[T_Handler, Handler]]] = ...,
temp: bool = ...,
priority: int = ...,
block: bool = ...,
state: Optional[T_State] = ...,
state_factory: Optional[T_StateFactory] = ...,
): ...
def on(
self,
*,
type: str = ...,
rule: Optional[Union[Rule, T_RuleChecker]] = ...,
permission: Optional[Permission] = ...,
handlers: Optional[List[Union[T_Handler, Handler]]] = ...,
temp: bool = ...,
priority: int = ...,
block: bool = ...,
state: Optional[T_State] = ...,
state_factory: Optional[T_StateFactory] = ...,
) -> Type[Matcher]: ...
def on_metaevent(
self,
*,
rule: Optional[Union[Rule, T_RuleChecker]] = ...,
handlers: Optional[List[Union[T_Handler, Handler]]] = ...,
temp: bool = ...,
priority: int = ...,
block: bool = ...,
state: Optional[T_State] = ...,
state_factory: Optional[T_StateFactory] = ...,
) -> Type[Matcher]: ...
def on_message(
self,
*,
rule: Optional[Union[Rule, T_RuleChecker]] = ...,
permission: Optional[Permission] = ...,
handlers: Optional[List[Union[T_Handler, Handler]]] = ...,
temp: bool = ...,
priority: int = ...,
block: bool = ...,
state: Optional[T_State] = ...,
state_factory: Optional[T_StateFactory] = ...,
) -> Type[Matcher]: ...
def on_notice(
self,
*,
rule: Optional[Union[Rule, T_RuleChecker]] = ...,
handlers: Optional[List[Union[T_Handler, Handler]]] = ...,
temp: bool = ...,
priority: int = ...,
block: bool = ...,
state: Optional[T_State] = ...,
state_factory: Optional[T_StateFactory] = ...,
) -> Type[Matcher]: ...
def on_request(
self,
*,
rule: Optional[Union[Rule, T_RuleChecker]] = ...,
handlers: Optional[List[Union[T_Handler, Handler]]] = ...,
temp: bool = ...,
priority: int = ...,
block: bool = ...,
state: Optional[T_State] = ...,
state_factory: Optional[T_StateFactory] = ...,
) -> Type[Matcher]: ...
def on_startswith(
self,
msg: Union[str, Tuple[str, ...]],
*,
ignorecase: bool = ...,
rule: Optional[Union[Rule, T_RuleChecker]] = ...,
permission: Optional[Permission] = ...,
handlers: Optional[List[Union[T_Handler, Handler]]] = ...,
temp: bool = ...,
priority: int = ...,
block: bool = ...,
state: Optional[T_State] = ...,
state_factory: Optional[T_StateFactory] = ...,
) -> Type[Matcher]: ...
def on_endswith(
self,
msg: Union[str, Tuple[str, ...]],
*,
ignorecase: bool = ...,
rule: Optional[Union[Rule, T_RuleChecker]] = ...,
permission: Optional[Permission] = ...,
handlers: Optional[List[Union[T_Handler, Handler]]] = ...,
temp: bool = ...,
priority: int = ...,
block: bool = ...,
state: Optional[T_State] = ...,
state_factory: Optional[T_StateFactory] = ...,
) -> Type[Matcher]: ...
def on_keyword(
self,
keywords: Set[str],
*,
rule: Optional[Union[Rule, T_RuleChecker]] = ...,
permission: Optional[Permission] = ...,
handlers: Optional[List[Union[T_Handler, Handler]]] = ...,
temp: bool = ...,
priority: int = ...,
block: bool = ...,
state: Optional[T_State] = ...,
state_factory: Optional[T_StateFactory] = ...,
) -> Type[Matcher]: ...
def on_command(
self,
cmd: Union[str, Tuple[str, ...]],
aliases: Optional[Set[Union[str, Tuple[str, ...]]]] = ...,
*,
rule: Optional[Union[Rule, T_RuleChecker]] = ...,
permission: Optional[Permission] = ...,
handlers: Optional[List[Union[T_Handler, Handler]]] = ...,
temp: bool = ...,
priority: int = ...,
block: bool = ...,
state: Optional[T_State] = ...,
state_factory: Optional[T_StateFactory] = ...,
) -> Type[Matcher]: ...
def on_shell_command(
self,
cmd: Union[str, Tuple[str, ...]],
aliases: Optional[Set[Union[str, Tuple[str, ...]]]] = ...,
parser: Optional[ArgumentParser] = ...,
*,
rule: Optional[Union[Rule, T_RuleChecker]] = ...,
permission: Optional[Permission] = ...,
handlers: Optional[List[Union[T_Handler, Handler]]] = ...,
temp: bool = ...,
priority: int = ...,
block: bool = ...,
state: Optional[T_State] = ...,
state_factory: Optional[T_StateFactory] = ...) -> Type[Matcher]:
...
def on_regex(pattern: str,
flags: Union[int, re.RegexFlag] = ...,
rule: Optional[Union[Rule, T_RuleChecker]] = ...,
*,
permission: Optional[Permission] = ...,
handlers: Optional[List[Union[T_Handler, Handler]]] = ...,
temp: bool = ...,
priority: int = ...,
block: bool = ...,
state: Optional[T_State] = ...,
state_factory: Optional[T_StateFactory] = ...) -> Type[Matcher]:
...
class CommandGroup:
def __init__(self,
cmd: Union[str, Tuple[str, ...]],
*,
rule: Optional[Union[Rule, T_RuleChecker]] = ...,
permission: Optional[Permission] = ...,
handlers: Optional[List[Union[T_Handler, Handler]]] = ...,
temp: bool = ...,
priority: int = ...,
block: bool = ...,
state: Optional[T_State] = ...,
state_factory: Optional[T_StateFactory] = ...):
...
def command(self,
cmd: Union[str, Tuple[str, ...]],
*,
aliases: Optional[Set[Union[str, Tuple[str, ...]]]],
rule: Optional[Union[Rule, T_RuleChecker]] = ...,
permission: Optional[Permission] = ...,
handlers: Optional[List[Union[T_Handler, Handler]]] = ...,
temp: bool = ...,
priority: int = ...,
block: bool = ...,
state: Optional[T_State] = ...,
state_factory: Optional[T_StateFactory] = ...) -> Type[Matcher]:
...
def shell_command(
self,
cmd: Union[str, Tuple[str, ...]],
*,
rule: Optional[Union[Rule, T_RuleChecker]] = ...,
aliases: Optional[Set[Union[str, Tuple[str, ...]]]],
parser: Optional[ArgumentParser] = ...,
permission: Optional[Permission] = ...,
handlers: Optional[List[Union[T_Handler, Handler]]] = ...,
temp: bool = ...,
priority: int = ...,
block: bool = ...,
state: Optional[T_State] = ...,
state_factory: Optional[T_StateFactory] = ...) -> Type[Matcher]:
...
class MatcherGroup:
def __init__(self,
*,
type: str = ...,
rule: Optional[Union[Rule, T_RuleChecker]] = ...,
permission: Optional[Permission] = ...,
handlers: Optional[List[Union[T_Handler, Handler]]] = ...,
temp: bool = ...,
priority: int = ...,
block: bool = ...,
state: Optional[T_State] = ...,
state_factory: Optional[T_StateFactory] = ...):
...
def on(self,
*,
type: str = ...,
rule: Optional[Union[Rule, T_RuleChecker]] = ...,
permission: Optional[Permission] = ...,
handlers: Optional[List[Union[T_Handler, Handler]]] = ...,
temp: bool = ...,
priority: int = ...,
block: bool = ...,
state: Optional[T_State] = ...,
state_factory: Optional[T_StateFactory] = ...) -> Type[Matcher]:
...
def on_metaevent(
self,
*,
rule: Optional[Union[Rule, T_RuleChecker]] = ...,
handlers: Optional[List[Union[T_Handler, Handler]]] = ...,
temp: bool = ...,
priority: int = ...,
block: bool = ...,
state: Optional[T_State] = ...,
state_factory: Optional[T_StateFactory] = ...) -> Type[Matcher]:
...
def on_message(
self,
*,
rule: Optional[Union[Rule, T_RuleChecker]] = ...,
permission: Optional[Permission] = ...,
handlers: Optional[List[Union[T_Handler, Handler]]] = ...,
temp: bool = ...,
priority: int = ...,
block: bool = ...,
state: Optional[T_State] = ...,
state_factory: Optional[T_StateFactory] = ...) -> Type[Matcher]:
...
def on_notice(
self,
*,
rule: Optional[Union[Rule, T_RuleChecker]] = ...,
handlers: Optional[List[Union[T_Handler, Handler]]] = ...,
temp: bool = ...,
priority: int = ...,
block: bool = ...,
state: Optional[T_State] = ...,
state_factory: Optional[T_StateFactory] = ...) -> Type[Matcher]:
...
def on_request(
self,
*,
rule: Optional[Union[Rule, T_RuleChecker]] = ...,
handlers: Optional[List[Union[T_Handler, Handler]]] = ...,
temp: bool = ...,
priority: int = ...,
block: bool = ...,
state: Optional[T_State] = ...,
state_factory: Optional[T_StateFactory] = ...) -> Type[Matcher]:
...
def on_startswith(
self,
msg: Union[str, Tuple[str, ...]],
*,
ignorecase: bool = ...,
rule: Optional[Union[Rule, T_RuleChecker]] = ...,
permission: Optional[Permission] = ...,
handlers: Optional[List[Union[T_Handler, Handler]]] = ...,
temp: bool = ...,
priority: int = ...,
block: bool = ...,
state: Optional[T_State] = ...,
state_factory: Optional[T_StateFactory] = ...) -> Type[Matcher]:
...
def on_endswith(
self,
msg: Union[str, Tuple[str, ...]],
*,
ignorecase: bool = ...,
rule: Optional[Union[Rule, T_RuleChecker]] = ...,
permission: Optional[Permission] = ...,
handlers: Optional[List[Union[T_Handler, Handler]]] = ...,
temp: bool = ...,
priority: int = ...,
block: bool = ...,
state: Optional[T_State] = ...,
state_factory: Optional[T_StateFactory] = ...) -> Type[Matcher]:
...
def on_keyword(
self,
keywords: Set[str],
*,
rule: Optional[Union[Rule, T_RuleChecker]] = ...,
permission: Optional[Permission] = ...,
handlers: Optional[List[Union[T_Handler, Handler]]] = ...,
temp: bool = ...,
priority: int = ...,
block: bool = ...,
state: Optional[T_State] = ...,
state_factory: Optional[T_StateFactory] = ...) -> Type[Matcher]:
...
def on_command(
self,
cmd: Union[str, Tuple[str, ...]],
aliases: Optional[Set[Union[str, Tuple[str, ...]]]] = ...,
*,
rule: Optional[Union[Rule, T_RuleChecker]] = ...,
permission: Optional[Permission] = ...,
handlers: Optional[List[Union[T_Handler, Handler]]] = ...,
temp: bool = ...,
priority: int = ...,
block: bool = ...,
state: Optional[T_State] = ...,
state_factory: Optional[T_StateFactory] = ...) -> Type[Matcher]:
...
def on_shell_command(
self,
cmd: Union[str, Tuple[str, ...]],
aliases: Optional[Set[Union[str, Tuple[str, ...]]]] = ...,
parser: Optional[ArgumentParser] = ...,
*,
rule: Optional[Union[Rule, T_RuleChecker]] = ...,
permission: Optional[Permission] = ...,
handlers: Optional[List[Union[T_Handler, Handler]]] = ...,
temp: bool = ...,
priority: int = ...,
block: bool = ...,
state: Optional[T_State] = ...,
state_factory: Optional[T_StateFactory] = ...) -> Type[Matcher]:
...
state_factory: Optional[T_StateFactory] = ...,
) -> Type[Matcher]: ...
def on_regex(
self,
pattern: str,
flags: Union[int, re.RegexFlag] = ...,
*,
rule: Optional[Union[Rule, T_RuleChecker]] = ...,
permission: Optional[Permission] = ...,
handlers: Optional[List[Union[T_Handler, Handler]]] = ...,
temp: bool = ...,
priority: int = ...,
block: bool = ...,
state: Optional[T_State] = ...,
state_factory: Optional[T_StateFactory] = ...) -> Type[Matcher]:
...
self,
pattern: str,
flags: Union[int, re.RegexFlag] = ...,
*,
rule: Optional[Union[Rule, T_RuleChecker]] = ...,
permission: Optional[Permission] = ...,
handlers: Optional[List[Union[T_Handler, Handler]]] = ...,
temp: bool = ...,
priority: int = ...,
block: bool = ...,
state: Optional[T_State] = ...,
state_factory: Optional[T_StateFactory] = ...,
) -> Type[Matcher]: ...

View File

@ -15,6 +15,7 @@ plugins: Dict[str, "Plugin"] = {}
@dataclass(eq=False)
class Plugin(object):
"""存储插件信息"""
name: str
"""
- **类型**: ``str``

View File

@ -3,15 +3,18 @@ from functools import reduce
from nonebot.rule import to_me
from nonebot.plugin import on_command
from nonebot.permission import SUPERUSER
from nonebot.adapters.cqhttp import (Message, MessageEvent, MessageSegment,
unescape)
from nonebot.adapters.cqhttp import (
Message,
MessageEvent,
MessageSegment,
unescape,
)
say = on_command("say", to_me(), permission=SUPERUSER)
@say.handle()
async def say_unescape(event: MessageEvent):
def _unescape(message: Message, segment: MessageSegment):
if segment.is_text():
return message.append(unescape(str(segment)))

View File

@ -1,8 +1,11 @@
from typing import Dict
from nonebot.adapters import Event
from nonebot.message import (IgnoredException, run_preprocessor,
run_postprocessor)
from nonebot.message import (
IgnoredException,
run_preprocessor,
run_postprocessor,
)
_running_matcher: Dict[str, int] = {}

View File

@ -17,8 +17,18 @@ from argparse import Namespace
from contextlib import AsyncExitStack
from typing_extensions import TypedDict
from argparse import ArgumentParser as ArgParser
from typing import (Any, Dict, List, Type, Tuple, Union, Callable, NoReturn,
Optional, Sequence)
from typing import (
Any,
Dict,
List,
Type,
Tuple,
Union,
Callable,
NoReturn,
Optional,
Sequence,
)
from pygtrie import CharTrie
@ -33,10 +43,9 @@ PREFIX_KEY = "_prefix"
SUFFIX_KEY = "_suffix"
CMD_KEY = "command"
RAW_CMD_KEY = "raw_command"
CMD_RESULT = TypedDict("CMD_RESULT", {
"command": Optional[Tuple[str, ...]],
"raw_command": Optional[str]
})
CMD_RESULT = TypedDict(
"CMD_RESULT", {"command": Optional[Tuple[str, ...]], "raw_command": Optional[str]}
)
SHELL_ARGS = "_args"
SHELL_ARGV = "_argv"
@ -61,11 +70,14 @@ class Rule:
from nonebot.utils import run_sync
Rule(async_function, run_sync(sync_function))
"""
__slots__ = ("checkers",)
HANDLER_PARAM_TYPES = [
params.BotParam, params.EventParam, params.StateParam,
params.DefaultParam
params.BotParam,
params.EventParam,
params.StateParam,
params.DefaultParam,
]
def __init__(self, *checkers: Union[T_RuleChecker, Handler]) -> None:
@ -76,9 +88,11 @@ class Rule:
"""
self.checkers = set(
checker if isinstance(checker, Handler) else Handler(
checker, allow_types=self.HANDLER_PARAM_TYPES)
for checker in checkers)
checker
if isinstance(checker, Handler)
else Handler(checker, allow_types=self.HANDLER_PARAM_TYPES)
for checker in checkers
)
"""
:说明:
@ -95,8 +109,8 @@ class Rule:
event: Event,
state: T_State,
stack: Optional[AsyncExitStack] = None,
dependency_cache: Optional[Dict[Callable[..., Any],
Any]] = None) -> bool:
dependency_cache: Optional[Dict[Callable[..., Any], Any]] = None,
) -> bool:
"""
:说明:
@ -117,12 +131,17 @@ class Rule:
if not self.checkers:
return True
results = await asyncio.gather(
*(checker(bot=bot,
event=event,
state=state,
_stack=stack,
_dependency_cache=dependency_cache)
for checker in self.checkers))
*(
checker(
bot=bot,
event=event,
state=state,
_stack=stack,
_dependency_cache=dependency_cache,
)
for checker in self.checkers
)
)
return all(results)
def __and__(self, other: Optional[Union["Rule", T_RuleChecker]]) -> "Rule":
@ -156,8 +175,9 @@ class TrieRule:
cls.suffix[suffix[::-1]] = value
@classmethod
def get_value(cls, bot: Bot, event: Event,
state: T_State) -> Tuple[CMD_RESULT, CMD_RESULT]:
def get_value(
cls, bot: Bot, event: Event, state: T_State
) -> Tuple[CMD_RESULT, CMD_RESULT]:
prefix = CMD_RESULT(command=None, raw_command=None)
suffix = CMD_RESULT(command=None, raw_command=None)
state[PREFIX_KEY] = prefix
@ -180,8 +200,7 @@ class TrieRule:
return prefix, suffix
def startswith(msg: Union[str, Tuple[str, ...]],
ignorecase: bool = False) -> Rule:
def startswith(msg: Union[str, Tuple[str, ...]], ignorecase: bool = False) -> Rule:
"""
:说明:
@ -196,7 +215,8 @@ def startswith(msg: Union[str, Tuple[str, ...]],
pattern = re.compile(
f"^(?:{'|'.join(re.escape(prefix) for prefix in msg)})",
re.IGNORECASE if ignorecase else 0)
re.IGNORECASE if ignorecase else 0,
)
async def _startswith(bot: Bot, event: Event, state: T_State) -> bool:
if event.get_type() != "message":
@ -207,8 +227,7 @@ def startswith(msg: Union[str, Tuple[str, ...]],
return Rule(_startswith)
def endswith(msg: Union[str, Tuple[str, ...]],
ignorecase: bool = False) -> Rule:
def endswith(msg: Union[str, Tuple[str, ...]], ignorecase: bool = False) -> Rule:
"""
:说明:
@ -223,7 +242,8 @@ def endswith(msg: Union[str, Tuple[str, ...]],
pattern = re.compile(
f"(?:{'|'.join(re.escape(prefix) for prefix in msg)})$",
re.IGNORECASE if ignorecase else 0)
re.IGNORECASE if ignorecase else 0,
)
async def _endswith(bot: Bot, event: Event, state: T_State) -> bool:
if event.get_type() != "message":
@ -314,19 +334,22 @@ class ArgumentParser(ArgParser):
setattr(self, "message", old_message)
def exit(self, status: int = 0, message: Optional[str] = None):
raise ParserExit(status=status,
message=message or getattr(self, "message", None))
raise ParserExit(
status=status, message=message or getattr(self, "message", None)
)
def parse_args(self,
args: Optional[Sequence[str]] = None,
namespace: Optional[Namespace] = None) -> Namespace:
def parse_args(
self,
args: Optional[Sequence[str]] = None,
namespace: Optional[Namespace] = None,
) -> Namespace:
setattr(self, "message", "")
return super().parse_args(args=args,
namespace=namespace) # type: ignore
return super().parse_args(args=args, namespace=namespace) # type: ignore
def shell_command(*cmds: Union[str, Tuple[str, ...]],
parser: Optional[ArgumentParser] = None) -> Rule:
def shell_command(
*cmds: Union[str, Tuple[str, ...]], parser: Optional[ArgumentParser] = None
) -> Rule:
r"""
:说明:
@ -361,8 +384,7 @@ def shell_command(*cmds: Union[str, Tuple[str, ...]],
\:\:\:
"""
if not isinstance(parser, ArgumentParser):
raise TypeError(
"`parser` must be an instance of nonebot.rule.ArgumentParser")
raise TypeError("`parser` must be an instance of nonebot.rule.ArgumentParser")
config = get_driver().config
command_start = config.command_start
@ -382,8 +404,7 @@ def shell_command(*cmds: Union[str, Tuple[str, ...]],
async def _shell_command(event: Event, state: T_State) -> bool:
if state[PREFIX_KEY][CMD_KEY] in commands:
message = str(event.get_message())
strip_message = message[len(state[PREFIX_KEY][RAW_CMD_KEY]
):].lstrip()
strip_message = message[len(state[PREFIX_KEY][RAW_CMD_KEY]) :].lstrip()
state[SHELL_ARGV] = shlex.split(strip_message)
if parser:
try:

View File

@ -18,8 +18,17 @@
https://docs.python.org/3/library/typing.html
"""
from typing import (TYPE_CHECKING, Any, Dict, Union, TypeVar, Callable,
NoReturn, Optional, Awaitable)
from typing import (
TYPE_CHECKING,
Any,
Dict,
Union,
TypeVar,
Callable,
NoReturn,
Optional,
Awaitable,
)
if TYPE_CHECKING:
from nonebot.adapters import Bot, Event
@ -29,10 +38,8 @@ T_Wrapped = TypeVar("T_Wrapped", bound=Callable)
def overrides(InterfaceClass: object):
def overrider(func: T_Wrapped) -> T_Wrapped:
assert func.__name__ in dir(
InterfaceClass), f"Error method: {func.__name__}"
assert func.__name__ in dir(InterfaceClass), f"Error method: {func.__name__}"
return func
return overrider
@ -80,7 +87,8 @@ T_CallingAPIHook = Callable[["Bot", str, Dict[str, Any]], Awaitable[None]]
``bot.call_api`` 时执行的函数
"""
T_CalledAPIHook = Callable[
["Bot", Optional[Exception], str, Dict[str, Any], Any], Awaitable[None]]
["Bot", Optional[Exception], str, Dict[str, Any], Any], Awaitable[None]
]
"""
:类型: ``Callable[[Bot, Optional[Exception], str, Dict[str, Any], Any], Awaitable[None]]``
@ -193,8 +201,9 @@ T_DependencyCache = Dict[T_Handler, Any]
依赖缓存, 用于存储依赖函数的返回值
"""
T_ArgsParser = Callable[["Bot", "Event", T_State], Union[Awaitable[None],
Awaitable[NoReturn]]]
T_ArgsParser = Callable[
["Bot", "Event", T_State], Union[Awaitable[None], Awaitable[NoReturn]]
]
"""
:类型: ``Callable[[Bot, Event, T_State], Union[Awaitable[None], Awaitable[NoReturn]]]``
@ -210,8 +219,9 @@ T_TypeUpdater = Callable[["Bot", "Event", T_State, str], Awaitable[str]]
TypeUpdater 在 Matcher.pause, Matcher.reject 时被运行,用于更新响应的事件类型。默认会更新为 ``message``。
"""
T_PermissionUpdater = Callable[["Bot", "Event", T_State, "Permission"],
Awaitable["Permission"]]
T_PermissionUpdater = Callable[
["Bot", "Event", T_State, "Permission"], Awaitable["Permission"]
]
"""
:类型: ``Callable[[Bot, Event, T_State, Permission], Awaitable[Permission]]``

View File

@ -8,8 +8,19 @@ from functools import wraps, partial
from contextlib import asynccontextmanager
from typing_extensions import GenericAlias # type: ignore
from typing_extensions import ParamSpec, get_args, get_origin
from typing import (Any, Type, Deque, Tuple, Union, TypeVar, Callable, Optional,
Awaitable, AsyncGenerator, ContextManager)
from typing import (
Any,
Type,
Deque,
Tuple,
Union,
TypeVar,
Callable,
Optional,
Awaitable,
AsyncGenerator,
ContextManager,
)
from nonebot.log import logger
from nonebot.typing import overrides
@ -37,15 +48,16 @@ def escape_tag(s: str) -> str:
def generic_check_issubclass(
cls: Any, class_or_tuple: Union[Type[Any], Tuple[Type[Any],
...]]) -> bool:
cls: Any, class_or_tuple: Union[Type[Any], Tuple[Type[Any], ...]]
) -> bool:
try:
return issubclass(cls, class_or_tuple)
except TypeError:
if get_origin(cls) is Union:
for type_ in get_args(cls):
if type_ is not type(None) and not generic_check_issubclass(
type_, class_or_tuple):
type_, class_or_tuple
):
return False
return True
elif isinstance(cls, GenericAlias):
@ -104,7 +116,8 @@ def run_sync(func: Callable[P, R]) -> Callable[P, Awaitable[R]]:
@asynccontextmanager
async def run_sync_ctx_manager(
cm: ContextManager[T],) -> AsyncGenerator[T, None]:
cm: ContextManager[T],
) -> AsyncGenerator[T, None]:
try:
yield await run_sync(cm.__enter__)()
except Exception as e:
@ -122,7 +135,6 @@ def get_name(obj: Any) -> str:
class CacheLock:
def __init__(self):
self._waiters: Optional[Deque[asyncio.Future]] = None
self._locked = False
@ -144,8 +156,9 @@ class CacheLock:
return self._locked
async def acquire(self):
if (not self._locked and (self._waiters is None or
all(w.cancelled() for w in self._waiters))):
if not self._locked and (
self._waiters is None or all(w.cancelled() for w in self._waiters)
):
self._locked = True
return True
@ -223,6 +236,7 @@ def logger_wrapper(logger_name: str):
def log(level: str, message: str, exception: Optional[Exception] = None):
return logger.opt(colors=True, exception=exception).log(
level, f"<m>{escape_tag(logger_name)}</m> | " + message)
level, f"<m>{escape_tag(logger_name)}</m> | " + message
)
return log