improve plugin system (#1011)

This commit is contained in:
Ju4tCode
2022-05-26 16:35:47 +08:00
committed by GitHub
parent 579839f2a4
commit fa3ed2b58c
15 changed files with 254 additions and 106 deletions

View File

@ -19,23 +19,52 @@ from typing import Set, Dict, List, Union, Iterable, Optional, Sequence
from nonebot.log import logger
from nonebot.utils import escape_tag
from . import _managers, _current_plugin
from .plugin import Plugin, _new_plugin, _confirm_plugin
from .plugin import Plugin
from . import (
_managers,
_new_plugin,
_revert_plugin,
_current_plugin,
_module_name_to_plugin_name,
)
class PluginManager:
"""插件管理器。
参数:
plugins: 独立插件模块名集合。
search_path: 插件搜索路径(文件夹)。
"""
def __init__(
self,
plugins: Optional[Iterable[str]] = None,
search_path: Optional[Iterable[str]] = None,
):
# simple plugin not in search path
self.plugins: Set[str] = set(plugins or [])
self.search_path: Set[str] = set(search_path or [])
# cache plugins
self.searched_plugins: Dict[str, Path] = {}
self.list_plugins()
self._third_party_plugin_names: Dict[str, str] = {}
self._searched_plugin_names: Dict[str, Path] = {}
self.prepare_plugins()
@property
def third_party_plugins(self) -> Set[str]:
"""返回所有独立插件名称。"""
return set(self._third_party_plugin_names.keys())
@property
def searched_plugins(self) -> Set[str]:
"""返回已搜索到的插件名称。"""
return set(self._searched_plugin_names.keys())
@property
def available_plugins(self) -> Set[str]:
"""返回当前插件管理器中可用的插件名称。"""
return self.third_party_plugins | self.searched_plugins
def _path_to_module_name(self, path: Path) -> str:
rel_path = path.resolve().relative_to(Path(".").resolve())
@ -44,48 +73,51 @@ class PluginManager:
else:
return ".".join(rel_path.parts[:-1] + (rel_path.stem,))
def _previous_plugins(self) -> List[str]:
def _previous_plugins(self) -> Set[str]:
_pre_managers: List[PluginManager]
if self in _managers:
_pre_managers = _managers[: _managers.index(self)]
else:
_pre_managers = _managers[:]
return [
*chain.from_iterable(
[
*map(lambda x: x.rsplit(".", 1)[-1], manager.plugins),
*manager.searched_plugins.keys(),
]
for manager in _pre_managers
)
]
return {
*chain.from_iterable(manager.available_plugins for manager in _pre_managers)
}
def prepare_plugins(self) -> Set[str]:
"""搜索插件并缓存插件名称。"""
def list_plugins(self) -> Set[str]:
# get all previous ready to load plugins
previous_plugins = self._previous_plugins()
searched_plugins: Dict[str, Path] = {}
third_party_plugins: Set[str] = set()
third_party_plugins: Dict[str, str] = {}
# check third party plugins
for plugin in self.plugins:
name = plugin.rsplit(".", 1)[-1]
name = _module_name_to_plugin_name(plugin)
if name in third_party_plugins or name in previous_plugins:
raise RuntimeError(
f"Plugin already exists: {name}! Check your plugin name"
)
third_party_plugins.add(plugin)
third_party_plugins[name] = plugin
self._third_party_plugin_names = third_party_plugins
# check plugins in search path
for module_info in pkgutil.iter_modules(self.search_path):
# ignore if startswith "_"
if module_info.name.startswith("_"):
continue
if (
module_info.name in searched_plugins.keys()
module_info.name in searched_plugins
or module_info.name in previous_plugins
or module_info.name in third_party_plugins
):
raise RuntimeError(
f"Plugin already exists: {module_info.name}! Check your plugin name"
)
module_spec = module_info.module_finder.find_spec(module_info.name, None)
if not module_spec:
continue
@ -94,17 +126,27 @@ class PluginManager:
continue
searched_plugins[module_info.name] = Path(module_path).resolve()
self.searched_plugins = searched_plugins
self._searched_plugin_names = searched_plugins
return third_party_plugins | set(self.searched_plugins.keys())
return self.available_plugins
def load_plugin(self, name: str) -> Optional[Plugin]:
"""加载指定插件。
对于独立插件,可以使用完整插件模块名或者插件名称。
参数:
name: 插件名称。
"""
try:
if name in self.plugins:
module = importlib.import_module(name)
elif name in self.searched_plugins:
elif name in self._third_party_plugin_names:
module = importlib.import_module(self._third_party_plugin_names[name])
elif name in self._searched_plugin_names:
module = importlib.import_module(
self._path_to_module_name(self.searched_plugins[name])
self._path_to_module_name(self._searched_plugin_names[name])
)
else:
raise RuntimeError(f"Plugin not found: {name}! Check your plugin name")
@ -125,8 +167,10 @@ class PluginManager:
)
def load_all_plugins(self) -> Set[Plugin]:
"""加载所有可用插件。"""
return set(
filter(None, (self.load_plugin(name) for name in self.list_plugins()))
filter(None, (self.load_plugin(name) for name in self.available_plugins))
)
@ -147,9 +191,10 @@ class PluginFinder(MetaPathFinder):
module_path = Path(module_origin).resolve()
for manager in reversed(_managers):
# use path instead of name in case of submodule name conflict
if (
fullname in manager.plugins
or module_path in manager.searched_plugins.values()
or module_path in manager._searched_plugin_names.values()
):
module_spec.loader = PluginLoader(manager, fullname, module_origin)
return module_spec
@ -173,7 +218,11 @@ class PluginLoader(SourceFileLoader):
if self.loaded:
return
# create plugin before executing
plugin = _new_plugin(self.name, module, self.manager)
setattr(module, "__plugin__", plugin)
# detect parent plugin before entering current plugin context
parent_plugin = _current_plugin.get()
if parent_plugin and _managers.index(parent_plugin.manager) < _managers.index(
self.manager
@ -181,21 +230,18 @@ class PluginLoader(SourceFileLoader):
plugin.parent_plugin = parent_plugin
parent_plugin.sub_plugins.add(plugin)
# enter plugin context
_plugin_token = _current_plugin.set(plugin)
setattr(module, "__plugin__", plugin)
try:
super().exec_module(module)
except Exception:
_revert_plugin(plugin)
raise
finally:
# leave plugin context
_current_plugin.reset(_plugin_token)
# try:
# super().exec_module(module)
# except Exception as e:
# raise ImportError(
# f"Error when executing module {module_name} from {module.__file__}."
# ) from e
super().exec_module(module)
_confirm_plugin(plugin)
_current_plugin.reset(_plugin_token)
return