mirror of
https://github.com/nonebot/nonebot2.git
synced 2025-09-06 20:16:47 +00:00
⚡ improve plugin system (#1011)
This commit is contained in:
@ -19,23 +19,52 @@ from typing import Set, Dict, List, Union, Iterable, Optional, Sequence
|
||||
from nonebot.log import logger
|
||||
from nonebot.utils import escape_tag
|
||||
|
||||
from . import _managers, _current_plugin
|
||||
from .plugin import Plugin, _new_plugin, _confirm_plugin
|
||||
from .plugin import Plugin
|
||||
from . import (
|
||||
_managers,
|
||||
_new_plugin,
|
||||
_revert_plugin,
|
||||
_current_plugin,
|
||||
_module_name_to_plugin_name,
|
||||
)
|
||||
|
||||
|
||||
class PluginManager:
|
||||
"""插件管理器。
|
||||
|
||||
参数:
|
||||
plugins: 独立插件模块名集合。
|
||||
search_path: 插件搜索路径(文件夹)。
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
plugins: Optional[Iterable[str]] = None,
|
||||
search_path: Optional[Iterable[str]] = None,
|
||||
):
|
||||
|
||||
# simple plugin not in search path
|
||||
self.plugins: Set[str] = set(plugins or [])
|
||||
self.search_path: Set[str] = set(search_path or [])
|
||||
|
||||
# cache plugins
|
||||
self.searched_plugins: Dict[str, Path] = {}
|
||||
self.list_plugins()
|
||||
self._third_party_plugin_names: Dict[str, str] = {}
|
||||
self._searched_plugin_names: Dict[str, Path] = {}
|
||||
self.prepare_plugins()
|
||||
|
||||
@property
|
||||
def third_party_plugins(self) -> Set[str]:
|
||||
"""返回所有独立插件名称。"""
|
||||
return set(self._third_party_plugin_names.keys())
|
||||
|
||||
@property
|
||||
def searched_plugins(self) -> Set[str]:
|
||||
"""返回已搜索到的插件名称。"""
|
||||
return set(self._searched_plugin_names.keys())
|
||||
|
||||
@property
|
||||
def available_plugins(self) -> Set[str]:
|
||||
"""返回当前插件管理器中可用的插件名称。"""
|
||||
return self.third_party_plugins | self.searched_plugins
|
||||
|
||||
def _path_to_module_name(self, path: Path) -> str:
|
||||
rel_path = path.resolve().relative_to(Path(".").resolve())
|
||||
@ -44,48 +73,51 @@ class PluginManager:
|
||||
else:
|
||||
return ".".join(rel_path.parts[:-1] + (rel_path.stem,))
|
||||
|
||||
def _previous_plugins(self) -> List[str]:
|
||||
def _previous_plugins(self) -> Set[str]:
|
||||
_pre_managers: List[PluginManager]
|
||||
if self in _managers:
|
||||
_pre_managers = _managers[: _managers.index(self)]
|
||||
else:
|
||||
_pre_managers = _managers[:]
|
||||
|
||||
return [
|
||||
*chain.from_iterable(
|
||||
[
|
||||
*map(lambda x: x.rsplit(".", 1)[-1], manager.plugins),
|
||||
*manager.searched_plugins.keys(),
|
||||
]
|
||||
for manager in _pre_managers
|
||||
)
|
||||
]
|
||||
return {
|
||||
*chain.from_iterable(manager.available_plugins for manager in _pre_managers)
|
||||
}
|
||||
|
||||
def prepare_plugins(self) -> Set[str]:
|
||||
"""搜索插件并缓存插件名称。"""
|
||||
|
||||
def list_plugins(self) -> Set[str]:
|
||||
# get all previous ready to load plugins
|
||||
previous_plugins = self._previous_plugins()
|
||||
searched_plugins: Dict[str, Path] = {}
|
||||
third_party_plugins: Set[str] = set()
|
||||
third_party_plugins: Dict[str, str] = {}
|
||||
|
||||
# check third party plugins
|
||||
for plugin in self.plugins:
|
||||
name = plugin.rsplit(".", 1)[-1]
|
||||
name = _module_name_to_plugin_name(plugin)
|
||||
if name in third_party_plugins or name in previous_plugins:
|
||||
raise RuntimeError(
|
||||
f"Plugin already exists: {name}! Check your plugin name"
|
||||
)
|
||||
third_party_plugins.add(plugin)
|
||||
third_party_plugins[name] = plugin
|
||||
|
||||
self._third_party_plugin_names = third_party_plugins
|
||||
|
||||
# check plugins in search path
|
||||
for module_info in pkgutil.iter_modules(self.search_path):
|
||||
# ignore if startswith "_"
|
||||
if module_info.name.startswith("_"):
|
||||
continue
|
||||
|
||||
if (
|
||||
module_info.name in searched_plugins.keys()
|
||||
module_info.name in searched_plugins
|
||||
or module_info.name in previous_plugins
|
||||
or module_info.name in third_party_plugins
|
||||
):
|
||||
raise RuntimeError(
|
||||
f"Plugin already exists: {module_info.name}! Check your plugin name"
|
||||
)
|
||||
|
||||
module_spec = module_info.module_finder.find_spec(module_info.name, None)
|
||||
if not module_spec:
|
||||
continue
|
||||
@ -94,17 +126,27 @@ class PluginManager:
|
||||
continue
|
||||
searched_plugins[module_info.name] = Path(module_path).resolve()
|
||||
|
||||
self.searched_plugins = searched_plugins
|
||||
self._searched_plugin_names = searched_plugins
|
||||
|
||||
return third_party_plugins | set(self.searched_plugins.keys())
|
||||
return self.available_plugins
|
||||
|
||||
def load_plugin(self, name: str) -> Optional[Plugin]:
|
||||
"""加载指定插件。
|
||||
|
||||
对于独立插件,可以使用完整插件模块名或者插件名称。
|
||||
|
||||
参数:
|
||||
name: 插件名称。
|
||||
"""
|
||||
|
||||
try:
|
||||
if name in self.plugins:
|
||||
module = importlib.import_module(name)
|
||||
elif name in self.searched_plugins:
|
||||
elif name in self._third_party_plugin_names:
|
||||
module = importlib.import_module(self._third_party_plugin_names[name])
|
||||
elif name in self._searched_plugin_names:
|
||||
module = importlib.import_module(
|
||||
self._path_to_module_name(self.searched_plugins[name])
|
||||
self._path_to_module_name(self._searched_plugin_names[name])
|
||||
)
|
||||
else:
|
||||
raise RuntimeError(f"Plugin not found: {name}! Check your plugin name")
|
||||
@ -125,8 +167,10 @@ class PluginManager:
|
||||
)
|
||||
|
||||
def load_all_plugins(self) -> Set[Plugin]:
|
||||
"""加载所有可用插件。"""
|
||||
|
||||
return set(
|
||||
filter(None, (self.load_plugin(name) for name in self.list_plugins()))
|
||||
filter(None, (self.load_plugin(name) for name in self.available_plugins))
|
||||
)
|
||||
|
||||
|
||||
@ -147,9 +191,10 @@ class PluginFinder(MetaPathFinder):
|
||||
module_path = Path(module_origin).resolve()
|
||||
|
||||
for manager in reversed(_managers):
|
||||
# use path instead of name in case of submodule name conflict
|
||||
if (
|
||||
fullname in manager.plugins
|
||||
or module_path in manager.searched_plugins.values()
|
||||
or module_path in manager._searched_plugin_names.values()
|
||||
):
|
||||
module_spec.loader = PluginLoader(manager, fullname, module_origin)
|
||||
return module_spec
|
||||
@ -173,7 +218,11 @@ class PluginLoader(SourceFileLoader):
|
||||
if self.loaded:
|
||||
return
|
||||
|
||||
# create plugin before executing
|
||||
plugin = _new_plugin(self.name, module, self.manager)
|
||||
setattr(module, "__plugin__", plugin)
|
||||
|
||||
# detect parent plugin before entering current plugin context
|
||||
parent_plugin = _current_plugin.get()
|
||||
if parent_plugin and _managers.index(parent_plugin.manager) < _managers.index(
|
||||
self.manager
|
||||
@ -181,21 +230,18 @@ class PluginLoader(SourceFileLoader):
|
||||
plugin.parent_plugin = parent_plugin
|
||||
parent_plugin.sub_plugins.add(plugin)
|
||||
|
||||
# enter plugin context
|
||||
_plugin_token = _current_plugin.set(plugin)
|
||||
|
||||
setattr(module, "__plugin__", plugin)
|
||||
try:
|
||||
super().exec_module(module)
|
||||
except Exception:
|
||||
_revert_plugin(plugin)
|
||||
raise
|
||||
finally:
|
||||
# leave plugin context
|
||||
_current_plugin.reset(_plugin_token)
|
||||
|
||||
# try:
|
||||
# super().exec_module(module)
|
||||
# except Exception as e:
|
||||
# raise ImportError(
|
||||
# f"Error when executing module {module_name} from {module.__file__}."
|
||||
# ) from e
|
||||
super().exec_module(module)
|
||||
|
||||
_confirm_plugin(plugin)
|
||||
|
||||
_current_plugin.reset(_plugin_token)
|
||||
return
|
||||
|
||||
|
||||
|
Reference in New Issue
Block a user