diff --git a/nonebot/plugin/load.py b/nonebot/plugin/load.py index 2d2c55bb..03d5dbde 100644 --- a/nonebot/plugin/load.py +++ b/nonebot/plugin/load.py @@ -137,6 +137,12 @@ def load_builtin_plugins(*plugins) -> Set[Plugin]: return load_all_plugins([f"nonebot.plugins.{p}" for p in plugins], []) +def _find_manager_by_name(name: str) -> Optional[PluginManager]: + for manager in reversed(_managers): + if name in manager.plugins or name in manager.searched_plugins: + return manager + + def require(name: str) -> Export: """获取一个插件的导出内容。 @@ -148,7 +154,13 @@ def require(name: str) -> Export: 异常: RuntimeError: 插件无法加载 """ - plugin = get_plugin(name) or load_plugin(name) + plugin = get_plugin(name.rsplit(".", 1)[-1]) if not plugin: - raise RuntimeError(f'Cannot load plugin "{name}"!') + manager = _find_manager_by_name(name) + if manager: + plugin = manager.load_plugin(name) + else: + plugin = load_plugin(name) + if not plugin: + raise RuntimeError(f'Cannot load plugin "{name}"!') return plugin.export diff --git a/nonebot/plugin/manager.py b/nonebot/plugin/manager.py index c46c77c4..7e8b2826 100644 --- a/nonebot/plugin/manager.py +++ b/nonebot/plugin/manager.py @@ -53,7 +53,10 @@ class PluginManager: return [ *chain.from_iterable( - [*manager.plugins, *manager.searched_plugins.keys()] + [ + *map(lambda x: x.rsplit(".", 1)[-1], manager.plugins), + *manager.searched_plugins.keys(), + ] for manager in _pre_managers ) ] @@ -65,7 +68,7 @@ class PluginManager: third_party_plugins: Set[str] = set() for plugin in self.plugins: - name = plugin.rsplit(".", 1)[-1] if "." in plugin else plugin + name = plugin.rsplit(".", 1)[-1] if name in third_party_plugins or name in previous_plugins: raise RuntimeError( f"Plugin already exists: {name}! Check your plugin name" @@ -95,21 +98,27 @@ class PluginManager: return third_party_plugins | set(self.searched_plugins.keys()) - def load_plugin(self, name) -> Optional[Plugin]: + def load_plugin(self, name: str) -> Optional[Plugin]: try: if name in self.plugins: module = importlib.import_module(name) - elif name not in self.searched_plugins: - raise RuntimeError(f"Plugin not found: {name}! Check your plugin name") - else: + elif name in self.searched_plugins: module = importlib.import_module( self._path_to_module_name(self.searched_plugins[name]) ) + else: + raise RuntimeError(f"Plugin not found: {name}! Check your plugin name") logger.opt(colors=True).success( f'Succeeded to import "{escape_tag(name)}"' ) - return getattr(module, "__plugin__", None) + plugin = getattr(module, "__plugin__", None) + if plugin is None: + raise RuntimeError( + f"Module {module.__name__} is not loaded as a plugin! " + "Make sure not to import it before loading." + ) + return plugin except Exception as e: logger.opt(colors=True, exception=e).error( f'Failed to import "{escape_tag(name)}"' @@ -129,7 +138,6 @@ class PluginFinder(MetaPathFinder): target: Optional[ModuleType] = None, ): if _managers: - index = -1 module_spec = PathFinder.find_spec(fullname, path, target) if not module_spec: return @@ -138,17 +146,13 @@ class PluginFinder(MetaPathFinder): return module_path = Path(module_origin).resolve() - while -index <= len(_managers): - manager = _managers[index] - + for manager in reversed(_managers): if ( fullname in manager.plugins or module_path in manager.searched_plugins.values() ): module_spec.loader = PluginLoader(manager, fullname, module_origin) return module_spec - - index -= 1 return diff --git a/tests/test_init.py b/tests/test_init.py index 03da1755..6d36dccd 100644 --- a/tests/test_init.py +++ b/tests/test_init.py @@ -1,12 +1,7 @@ import os -import sys -from typing import TYPE_CHECKING, Set import pytest -if TYPE_CHECKING: - from nonebot.plugin import Plugin - os.environ["CONFIG_FROM_ENV"] = '{"test": "test"}' @@ -74,29 +69,3 @@ async def test_get(monkeypatch: pytest.MonkeyPatch, nonebug_clear): assert get_bot() == "test" assert get_bot("test") == "test" assert get_bots() == {"test": "test"} - - -@pytest.mark.asyncio -async def test_load_plugin(load_plugin: Set["Plugin"]): - import nonebot - - loaded_plugins = set( - plugin for plugin in nonebot.get_loaded_plugins() if not plugin.parent_plugin - ) - assert loaded_plugins == load_plugin - plugin = nonebot.get_plugin("export") - assert plugin - assert plugin.module_name == "plugins.export" - assert "plugins.export" in sys.modules - - try: - nonebot.load_plugin("plugins.export") - assert False - except RuntimeError: - assert True - - try: - nonebot.load_plugin("some_plugin_no_exist") - assert False - except Exception: - assert nonebot.get_plugin("some_plugin_no_exist") is None diff --git a/tests/test_plugin/test_load.py b/tests/test_plugin/test_load.py new file mode 100644 index 00000000..c0c0c8f6 --- /dev/null +++ b/tests/test_plugin/test_load.py @@ -0,0 +1,89 @@ +import sys +from typing import TYPE_CHECKING, Set + +import pytest +from nonebug import App + +if TYPE_CHECKING: + from nonebot.plugin import Plugin + + +@pytest.mark.asyncio +async def test_load_plugin(load_plugin: Set["Plugin"]): + import nonebot + + loaded_plugins = set( + plugin for plugin in nonebot.get_loaded_plugins() if not plugin.parent_plugin + ) + assert loaded_plugins == load_plugin + plugin = nonebot.get_plugin("export") + assert plugin + assert plugin.module_name == "plugins.export" + assert "plugins.export" in sys.modules + + try: + nonebot.load_plugin("plugins.export") + assert False + except RuntimeError: + assert True + + assert nonebot.load_plugin("some_plugin_not_exist") is None + + +@pytest.mark.asyncio +async def test_require_loaded(app: App, monkeypatch: pytest.MonkeyPatch): + import nonebot + + def _patched_find(name: str): + assert False + + monkeypatch.setattr("nonebot.plugin.load._find_manager_by_name", _patched_find) + + nonebot.load_plugin("plugins.export") + + nonebot.require("plugins.export") + + +@pytest.mark.asyncio +async def test_require_not_loaded(app: App, monkeypatch: pytest.MonkeyPatch): + import nonebot + from nonebot.plugin import _managers + from nonebot.plugin.manager import PluginManager + + m = PluginManager(["plugins.export"]) + _managers.append(m) + + origin_load = PluginManager.load_plugin + + def _patched_load(self: PluginManager, name: str): + assert self is m + return origin_load(self, name) + + monkeypatch.setattr(PluginManager, "load_plugin", _patched_load) + + nonebot.require("plugins.export") + + assert len(_managers) == 1 + + +@pytest.mark.asyncio +async def test_require_not_declared(app: App): + import nonebot + from nonebot.plugin import _managers + + nonebot.require("plugins.export") + + assert len(_managers) == 1 + assert _managers[-1].plugins == {"plugins.export"} + + +@pytest.mark.asyncio +async def test_require_not_found(app: App): + import nonebot + from nonebot.plugin import _managers + + try: + nonebot.require("some_plugin_not_exist") + assert False + except RuntimeError: + assert True