⚗️ add export require option

This commit is contained in:
yanyongyu
2020-11-21 20:40:09 +08:00
parent b36f95862a
commit 9373bd09ed
6 changed files with 145 additions and 37 deletions

View File

@ -11,6 +11,7 @@ import pkgutil
import importlib
from dataclasses import dataclass
from importlib._bootstrap import _load
from contextvars import Context, ContextVar, copy_context
from nonebot.log import logger
from nonebot.matcher import Matcher
@ -25,7 +26,45 @@ plugins: Dict[str, "Plugin"] = {}
:说明: 已加载的插件
"""
_tmp_matchers: Set[Type[Matcher]] = set()
_tmp_matchers: ContextVar[Set[Type[Matcher]]] = ContextVar("_tmp_matchers")
_export: ContextVar["Export"] = ContextVar("_export")
class Export(dict):
"""
:说明:
插件导出内容以使得其他插件可以获得。
:示例:
.. code-block:: python
nonebot.export().default = "bar"
@nonebot.export()
def some_function():
pass
@nonebot.export().sub
def something_else():
pass
"""
def __call__(self, func, **kwargs):
self[func.__name__] = func
self.update(kwargs)
return func
def __setitem__(self, key, value):
super().__setitem__(key,
Export(value) if isinstance(value, dict) else value)
def __setattr__(self, name, value):
self[name] = Export(value) if isinstance(value, dict) else value
def __getattr__(self, name):
if name not in self:
self[name] = Export()
return self[name]
@dataclass(eq=False)
@ -46,6 +85,7 @@ class Plugin(object):
- **类型**: ``Set[Type[Matcher]]``
- **说明**: 插件内定义的 ``Matcher``
"""
export: Export
def on(type: str = "",
@ -80,7 +120,7 @@ def on(type: str = "",
block=block,
handlers=handlers,
default_state=state)
_tmp_matchers.add(matcher)
_tmp_matchers.get().add(matcher)
return matcher
@ -112,7 +152,7 @@ def on_metaevent(rule: Optional[Union[Rule, RuleChecker]] = None,
block=block,
handlers=handlers,
default_state=state)
_tmp_matchers.add(matcher)
_tmp_matchers.get().add(matcher)
return matcher
@ -146,7 +186,7 @@ def on_message(rule: Optional[Union[Rule, RuleChecker]] = None,
block=block,
handlers=handlers,
default_state=state)
_tmp_matchers.add(matcher)
_tmp_matchers.get().add(matcher)
return matcher
@ -178,7 +218,7 @@ def on_notice(rule: Optional[Union[Rule, RuleChecker]] = None,
block=block,
handlers=handlers,
default_state=state)
_tmp_matchers.add(matcher)
_tmp_matchers.get().add(matcher)
return matcher
@ -210,7 +250,7 @@ def on_request(rule: Optional[Union[Rule, RuleChecker]] = None,
block=block,
handlers=handlers,
default_state=state)
_tmp_matchers.add(matcher)
_tmp_matchers.get().add(matcher)
return matcher
@ -387,27 +427,35 @@ def load_plugin(module_path: str) -> Optional[Plugin]:
:返回:
- ``Optional[Plugin]``
"""
try:
_tmp_matchers.clear()
if module_path in plugins:
return plugins[module_path]
elif module_path in sys.modules:
logger.warning(
f"Module {module_path} has been loaded by other plugins! Ignored"
def _load_plugin(module_path: str) -> Optional[Plugin]:
try:
_tmp_matchers.set(set())
_export.set(Export())
if module_path in plugins:
return plugins[module_path]
elif module_path in sys.modules:
logger.warning(
f"Module {module_path} has been loaded by other plugins! Ignored"
)
return
module = importlib.import_module(module_path)
for m in _tmp_matchers.get():
m.module = module_path
plugin = Plugin(module_path, module, _tmp_matchers.get(),
_export.get())
plugins[module_path] = plugin
logger.opt(
colors=True).info(f'Succeeded to import "<y>{module_path}</y>"')
return plugin
except Exception as e:
logger.opt(colors=True, exception=e).error(
f'<r><bg #f8bbd0>Failed to import "{module_path}"</bg #f8bbd0></r>'
)
return
module = importlib.import_module(module_path)
for m in _tmp_matchers:
m.module = module_path
plugin = Plugin(module_path, module, _tmp_matchers.copy())
plugins[module_path] = plugin
logger.opt(
colors=True).info(f'Succeeded to import "<y>{module_path}</y>"')
return plugin
except Exception as e:
logger.opt(colors=True, exception=e).error(
f'<r><bg #f8bbd0>Failed to import "{module_path}"</bg #f8bbd0></r>')
return None
return None
context: Context = copy_context()
return context.run(_load_plugin, module_path)
def load_plugins(*plugin_dir: str) -> Set[Plugin]:
@ -419,33 +467,42 @@ def load_plugins(*plugin_dir: str) -> Set[Plugin]:
:返回:
- ``Set[Plugin]``
"""
loaded_plugins = set()
for module_info in pkgutil.iter_modules(plugin_dir):
_tmp_matchers.clear()
def _load_plugin(module_info) -> Optional[Plugin]:
_tmp_matchers.set(set())
_export.set(Export())
name = module_info.name
if name.startswith("_"):
continue
return
spec = module_info.module_finder.find_spec(name, None)
if spec.name in plugins:
continue
return
elif spec.name in sys.modules:
logger.warning(
f"Module {spec.name} has been loaded by other plugin! Ignored")
continue
return
try:
module = _load(spec)
for m in _tmp_matchers:
for m in _tmp_matchers.get():
m.module = name
plugin = Plugin(name, module, _tmp_matchers.copy())
plugin = Plugin(name, module, _tmp_matchers.get(), _export.get())
plugins[name] = plugin
loaded_plugins.add(plugin)
logger.opt(colors=True).info(f'Succeeded to import "<y>{name}</y>"')
return plugin
except Exception as e:
logger.opt(colors=True, exception=e).error(
f'<r><bg #f8bbd0>Failed to import "{name}"</bg #f8bbd0></r>')
return None
loaded_plugins = set()
for module_info in pkgutil.iter_modules(plugin_dir):
context: Context = copy_context()
result = context.run(_load_plugin, module_info)
if result:
loaded_plugins.add(result)
return loaded_plugins
@ -479,3 +536,12 @@ def get_loaded_plugins() -> Set[Plugin]:
- ``Set[Plugin]``
"""
return set(plugins.values())
def export() -> Export:
return _export.get()
def require(name: str) -> Optional[Export]:
plugin = get_plugin(name)
return plugin.export if plugin else None