🐛 fix import hook export

This commit is contained in:
yanyongyu
2021-03-31 20:38:00 +08:00
parent d1e8925fe0
commit ca08c56df7
3 changed files with 80 additions and 69 deletions

View File

@ -20,6 +20,7 @@ from nonebot.permission import Permission
from nonebot.typing import T_State, T_StateFactory, T_Handler, T_RuleChecker
from nonebot.rule import Rule, startswith, endswith, keyword, command, shell_command, ArgumentParser, regex
from .export import Export, export, _export
from .manager import PluginManager, _current_plugin
if TYPE_CHECKING:
@ -32,55 +33,9 @@ plugins: Dict[str, "Plugin"] = {}
"""
PLUGIN_NAMESPACE = "nonebot.loaded_plugins"
_export: ContextVar["Export"] = ContextVar("_export")
# FIXME: tmp matchers context var will be removed
_plugin_matchers: Dict[str, Set[Type[Matcher]]] = defaultdict(set)
class Export(dict):
"""
:说明:
插件导出内容以使得其他插件可以获得。
:示例:
.. code-block:: python
nonebot.export().default = "bar"
@nonebot.export()
def some_function():
pass
# this doesn't work before python 3.9
# use
# export = nonebot.export(); @export.sub
# instead
# See also PEP-614: https://www.python.org/dev/peps/pep-0614/
@nonebot.export().sub
def something_else():
pass
"""
def __call__(self, func, **kwargs):
self[func.__name__] = func
self.update(kwargs)
return func
def __setitem__(self, key, value):
super().__setitem__(key,
Export(value) if isinstance(value, dict) else value)
def __setattr__(self, name, value):
self[name] = Export(value) if isinstance(value, dict) else value
def __getattr__(self, name):
if name not in self:
self[name] = Export()
return self[name]
@dataclass(eq=False)
class Plugin(object):
"""存储插件信息"""
@ -966,15 +921,14 @@ def _load_plugin(manager: PluginManager, plugin_name: str) -> Optional[Plugin]:
if plugin_name.startswith("_"):
return None
_export.set(Export())
if plugin_name in plugins:
return None
try:
module = manager.load_plugin(plugin_name)
plugin = Plugin(plugin_name, module, _export.get())
plugin = Plugin(plugin_name, module,
getattr(module, "__export__", Export()))
plugins[plugin_name] = plugin
logger.opt(
colors=True).info(f'Succeeded to import "<y>{plugin_name}</y>"')
@ -1153,19 +1107,6 @@ def get_loaded_plugins() -> Set[Plugin]:
return set(plugins.values())
def export() -> Export:
"""
:说明:
获取插件的导出内容对象
:返回:
- ``Export``
"""
return _export.get()
def require(name: str) -> Optional[Export]:
"""
:说明: