diff --git a/nonebot/plugin/load.py b/nonebot/plugin/load.py
index 2d2c55bb..03d5dbde 100644
--- a/nonebot/plugin/load.py
+++ b/nonebot/plugin/load.py
@@ -137,6 +137,12 @@ def load_builtin_plugins(*plugins) -> Set[Plugin]:
return load_all_plugins([f"nonebot.plugins.{p}" for p in plugins], [])
+def _find_manager_by_name(name: str) -> Optional[PluginManager]:
+ for manager in reversed(_managers):
+ if name in manager.plugins or name in manager.searched_plugins:
+ return manager
+
+
def require(name: str) -> Export:
"""获取一个插件的导出内容。
@@ -148,7 +154,13 @@ def require(name: str) -> Export:
异常:
RuntimeError: 插件无法加载
"""
- plugin = get_plugin(name) or load_plugin(name)
+ plugin = get_plugin(name.rsplit(".", 1)[-1])
if not plugin:
- raise RuntimeError(f'Cannot load plugin "{name}"!')
+ manager = _find_manager_by_name(name)
+ if manager:
+ plugin = manager.load_plugin(name)
+ else:
+ plugin = load_plugin(name)
+ if not plugin:
+ raise RuntimeError(f'Cannot load plugin "{name}"!')
return plugin.export
diff --git a/nonebot/plugin/manager.py b/nonebot/plugin/manager.py
index c46c77c4..7e8b2826 100644
--- a/nonebot/plugin/manager.py
+++ b/nonebot/plugin/manager.py
@@ -53,7 +53,10 @@ class PluginManager:
return [
*chain.from_iterable(
- [*manager.plugins, *manager.searched_plugins.keys()]
+ [
+ *map(lambda x: x.rsplit(".", 1)[-1], manager.plugins),
+ *manager.searched_plugins.keys(),
+ ]
for manager in _pre_managers
)
]
@@ -65,7 +68,7 @@ class PluginManager:
third_party_plugins: Set[str] = set()
for plugin in self.plugins:
- name = plugin.rsplit(".", 1)[-1] if "." in plugin else plugin
+ name = plugin.rsplit(".", 1)[-1]
if name in third_party_plugins or name in previous_plugins:
raise RuntimeError(
f"Plugin already exists: {name}! Check your plugin name"
@@ -95,21 +98,27 @@ class PluginManager:
return third_party_plugins | set(self.searched_plugins.keys())
- def load_plugin(self, name) -> Optional[Plugin]:
+ def load_plugin(self, name: str) -> Optional[Plugin]:
try:
if name in self.plugins:
module = importlib.import_module(name)
- elif name not in self.searched_plugins:
- raise RuntimeError(f"Plugin not found: {name}! Check your plugin name")
- else:
+ elif name in self.searched_plugins:
module = importlib.import_module(
self._path_to_module_name(self.searched_plugins[name])
)
+ else:
+ raise RuntimeError(f"Plugin not found: {name}! Check your plugin name")
logger.opt(colors=True).success(
f'Succeeded to import "{escape_tag(name)}"'
)
- return getattr(module, "__plugin__", None)
+ plugin = getattr(module, "__plugin__", None)
+ if plugin is None:
+ raise RuntimeError(
+ f"Module {module.__name__} is not loaded as a plugin! "
+ "Make sure not to import it before loading."
+ )
+ return plugin
except Exception as e:
logger.opt(colors=True, exception=e).error(
f'Failed to import "{escape_tag(name)}"'
@@ -129,7 +138,6 @@ class PluginFinder(MetaPathFinder):
target: Optional[ModuleType] = None,
):
if _managers:
- index = -1
module_spec = PathFinder.find_spec(fullname, path, target)
if not module_spec:
return
@@ -138,17 +146,13 @@ class PluginFinder(MetaPathFinder):
return
module_path = Path(module_origin).resolve()
- while -index <= len(_managers):
- manager = _managers[index]
-
+ for manager in reversed(_managers):
if (
fullname in manager.plugins
or module_path in manager.searched_plugins.values()
):
module_spec.loader = PluginLoader(manager, fullname, module_origin)
return module_spec
-
- index -= 1
return
diff --git a/tests/test_init.py b/tests/test_init.py
index 03da1755..6d36dccd 100644
--- a/tests/test_init.py
+++ b/tests/test_init.py
@@ -1,12 +1,7 @@
import os
-import sys
-from typing import TYPE_CHECKING, Set
import pytest
-if TYPE_CHECKING:
- from nonebot.plugin import Plugin
-
os.environ["CONFIG_FROM_ENV"] = '{"test": "test"}'
@@ -74,29 +69,3 @@ async def test_get(monkeypatch: pytest.MonkeyPatch, nonebug_clear):
assert get_bot() == "test"
assert get_bot("test") == "test"
assert get_bots() == {"test": "test"}
-
-
-@pytest.mark.asyncio
-async def test_load_plugin(load_plugin: Set["Plugin"]):
- import nonebot
-
- loaded_plugins = set(
- plugin for plugin in nonebot.get_loaded_plugins() if not plugin.parent_plugin
- )
- assert loaded_plugins == load_plugin
- plugin = nonebot.get_plugin("export")
- assert plugin
- assert plugin.module_name == "plugins.export"
- assert "plugins.export" in sys.modules
-
- try:
- nonebot.load_plugin("plugins.export")
- assert False
- except RuntimeError:
- assert True
-
- try:
- nonebot.load_plugin("some_plugin_no_exist")
- assert False
- except Exception:
- assert nonebot.get_plugin("some_plugin_no_exist") is None
diff --git a/tests/test_plugin/test_load.py b/tests/test_plugin/test_load.py
new file mode 100644
index 00000000..c0c0c8f6
--- /dev/null
+++ b/tests/test_plugin/test_load.py
@@ -0,0 +1,89 @@
+import sys
+from typing import TYPE_CHECKING, Set
+
+import pytest
+from nonebug import App
+
+if TYPE_CHECKING:
+ from nonebot.plugin import Plugin
+
+
+@pytest.mark.asyncio
+async def test_load_plugin(load_plugin: Set["Plugin"]):
+ import nonebot
+
+ loaded_plugins = set(
+ plugin for plugin in nonebot.get_loaded_plugins() if not plugin.parent_plugin
+ )
+ assert loaded_plugins == load_plugin
+ plugin = nonebot.get_plugin("export")
+ assert plugin
+ assert plugin.module_name == "plugins.export"
+ assert "plugins.export" in sys.modules
+
+ try:
+ nonebot.load_plugin("plugins.export")
+ assert False
+ except RuntimeError:
+ assert True
+
+ assert nonebot.load_plugin("some_plugin_not_exist") is None
+
+
+@pytest.mark.asyncio
+async def test_require_loaded(app: App, monkeypatch: pytest.MonkeyPatch):
+ import nonebot
+
+ def _patched_find(name: str):
+ assert False
+
+ monkeypatch.setattr("nonebot.plugin.load._find_manager_by_name", _patched_find)
+
+ nonebot.load_plugin("plugins.export")
+
+ nonebot.require("plugins.export")
+
+
+@pytest.mark.asyncio
+async def test_require_not_loaded(app: App, monkeypatch: pytest.MonkeyPatch):
+ import nonebot
+ from nonebot.plugin import _managers
+ from nonebot.plugin.manager import PluginManager
+
+ m = PluginManager(["plugins.export"])
+ _managers.append(m)
+
+ origin_load = PluginManager.load_plugin
+
+ def _patched_load(self: PluginManager, name: str):
+ assert self is m
+ return origin_load(self, name)
+
+ monkeypatch.setattr(PluginManager, "load_plugin", _patched_load)
+
+ nonebot.require("plugins.export")
+
+ assert len(_managers) == 1
+
+
+@pytest.mark.asyncio
+async def test_require_not_declared(app: App):
+ import nonebot
+ from nonebot.plugin import _managers
+
+ nonebot.require("plugins.export")
+
+ assert len(_managers) == 1
+ assert _managers[-1].plugins == {"plugins.export"}
+
+
+@pytest.mark.asyncio
+async def test_require_not_found(app: App):
+ import nonebot
+ from nonebot.plugin import _managers
+
+ try:
+ nonebot.require("some_plugin_not_exist")
+ assert False
+ except RuntimeError:
+ assert True