mirror of
				https://github.com/nonebot/nonebot2.git
				synced 2025-10-26 12:36:40 +00:00 
			
		
		
		
	⚗️ add export require option
This commit is contained in:
		| @@ -11,6 +11,7 @@ import pkgutil | ||||
| import importlib | ||||
| from dataclasses import dataclass | ||||
| from importlib._bootstrap import _load | ||||
| from contextvars import Context, ContextVar, copy_context | ||||
|  | ||||
| from nonebot.log import logger | ||||
| from nonebot.matcher import Matcher | ||||
| @@ -25,7 +26,45 @@ plugins: Dict[str, "Plugin"] = {} | ||||
| :说明: 已加载的插件 | ||||
| """ | ||||
|  | ||||
| _tmp_matchers: Set[Type[Matcher]] = set() | ||||
| _tmp_matchers: ContextVar[Set[Type[Matcher]]] = ContextVar("_tmp_matchers") | ||||
| _export: ContextVar["Export"] = ContextVar("_export") | ||||
|  | ||||
|  | ||||
| class Export(dict): | ||||
|     """ | ||||
|     :说明: | ||||
|       插件导出内容以使得其他插件可以获得。 | ||||
|     :示例: | ||||
|  | ||||
|     .. code-block:: python | ||||
|  | ||||
|         nonebot.export().default = "bar" | ||||
|  | ||||
|         @nonebot.export() | ||||
|         def some_function(): | ||||
|             pass | ||||
|  | ||||
|         @nonebot.export().sub | ||||
|         def something_else(): | ||||
|             pass | ||||
|     """ | ||||
|  | ||||
|     def __call__(self, func, **kwargs): | ||||
|         self[func.__name__] = func | ||||
|         self.update(kwargs) | ||||
|         return func | ||||
|  | ||||
|     def __setitem__(self, key, value): | ||||
|         super().__setitem__(key, | ||||
|                             Export(value) if isinstance(value, dict) else value) | ||||
|  | ||||
|     def __setattr__(self, name, value): | ||||
|         self[name] = Export(value) if isinstance(value, dict) else value | ||||
|  | ||||
|     def __getattr__(self, name): | ||||
|         if name not in self: | ||||
|             self[name] = Export() | ||||
|         return self[name] | ||||
|  | ||||
|  | ||||
| @dataclass(eq=False) | ||||
| @@ -46,6 +85,7 @@ class Plugin(object): | ||||
|     - **类型**: ``Set[Type[Matcher]]`` | ||||
|     - **说明**: 插件内定义的 ``Matcher`` | ||||
|     """ | ||||
|     export: Export | ||||
|  | ||||
|  | ||||
| def on(type: str = "", | ||||
| @@ -80,7 +120,7 @@ def on(type: str = "", | ||||
|                           block=block, | ||||
|                           handlers=handlers, | ||||
|                           default_state=state) | ||||
|     _tmp_matchers.add(matcher) | ||||
|     _tmp_matchers.get().add(matcher) | ||||
|     return matcher | ||||
|  | ||||
|  | ||||
| @@ -112,7 +152,7 @@ def on_metaevent(rule: Optional[Union[Rule, RuleChecker]] = None, | ||||
|                           block=block, | ||||
|                           handlers=handlers, | ||||
|                           default_state=state) | ||||
|     _tmp_matchers.add(matcher) | ||||
|     _tmp_matchers.get().add(matcher) | ||||
|     return matcher | ||||
|  | ||||
|  | ||||
| @@ -146,7 +186,7 @@ def on_message(rule: Optional[Union[Rule, RuleChecker]] = None, | ||||
|                           block=block, | ||||
|                           handlers=handlers, | ||||
|                           default_state=state) | ||||
|     _tmp_matchers.add(matcher) | ||||
|     _tmp_matchers.get().add(matcher) | ||||
|     return matcher | ||||
|  | ||||
|  | ||||
| @@ -178,7 +218,7 @@ def on_notice(rule: Optional[Union[Rule, RuleChecker]] = None, | ||||
|                           block=block, | ||||
|                           handlers=handlers, | ||||
|                           default_state=state) | ||||
|     _tmp_matchers.add(matcher) | ||||
|     _tmp_matchers.get().add(matcher) | ||||
|     return matcher | ||||
|  | ||||
|  | ||||
| @@ -210,7 +250,7 @@ def on_request(rule: Optional[Union[Rule, RuleChecker]] = None, | ||||
|                           block=block, | ||||
|                           handlers=handlers, | ||||
|                           default_state=state) | ||||
|     _tmp_matchers.add(matcher) | ||||
|     _tmp_matchers.get().add(matcher) | ||||
|     return matcher | ||||
|  | ||||
|  | ||||
| @@ -387,27 +427,35 @@ def load_plugin(module_path: str) -> Optional[Plugin]: | ||||
|     :返回: | ||||
|       - ``Optional[Plugin]`` | ||||
|     """ | ||||
|     try: | ||||
|         _tmp_matchers.clear() | ||||
|         if module_path in plugins: | ||||
|             return plugins[module_path] | ||||
|         elif module_path in sys.modules: | ||||
|             logger.warning( | ||||
|                 f"Module {module_path} has been loaded by other plugins! Ignored" | ||||
|  | ||||
|     def _load_plugin(module_path: str) -> Optional[Plugin]: | ||||
|         try: | ||||
|             _tmp_matchers.set(set()) | ||||
|             _export.set(Export()) | ||||
|             if module_path in plugins: | ||||
|                 return plugins[module_path] | ||||
|             elif module_path in sys.modules: | ||||
|                 logger.warning( | ||||
|                     f"Module {module_path} has been loaded by other plugins! Ignored" | ||||
|                 ) | ||||
|                 return | ||||
|             module = importlib.import_module(module_path) | ||||
|             for m in _tmp_matchers.get(): | ||||
|                 m.module = module_path | ||||
|             plugin = Plugin(module_path, module, _tmp_matchers.get(), | ||||
|                             _export.get()) | ||||
|             plugins[module_path] = plugin | ||||
|             logger.opt( | ||||
|                 colors=True).info(f'Succeeded to import "<y>{module_path}</y>"') | ||||
|             return plugin | ||||
|         except Exception as e: | ||||
|             logger.opt(colors=True, exception=e).error( | ||||
|                 f'<r><bg #f8bbd0>Failed to import "{module_path}"</bg #f8bbd0></r>' | ||||
|             ) | ||||
|             return | ||||
|         module = importlib.import_module(module_path) | ||||
|         for m in _tmp_matchers: | ||||
|             m.module = module_path | ||||
|         plugin = Plugin(module_path, module, _tmp_matchers.copy()) | ||||
|         plugins[module_path] = plugin | ||||
|         logger.opt( | ||||
|             colors=True).info(f'Succeeded to import "<y>{module_path}</y>"') | ||||
|         return plugin | ||||
|     except Exception as e: | ||||
|         logger.opt(colors=True, exception=e).error( | ||||
|             f'<r><bg #f8bbd0>Failed to import "{module_path}"</bg #f8bbd0></r>') | ||||
|         return None | ||||
|             return None | ||||
|  | ||||
|     context: Context = copy_context() | ||||
|     return context.run(_load_plugin, module_path) | ||||
|  | ||||
|  | ||||
| def load_plugins(*plugin_dir: str) -> Set[Plugin]: | ||||
| @@ -419,33 +467,42 @@ def load_plugins(*plugin_dir: str) -> Set[Plugin]: | ||||
|     :返回: | ||||
|       - ``Set[Plugin]`` | ||||
|     """ | ||||
|     loaded_plugins = set() | ||||
|     for module_info in pkgutil.iter_modules(plugin_dir): | ||||
|         _tmp_matchers.clear() | ||||
|  | ||||
|     def _load_plugin(module_info) -> Optional[Plugin]: | ||||
|         _tmp_matchers.set(set()) | ||||
|         _export.set(Export()) | ||||
|         name = module_info.name | ||||
|         if name.startswith("_"): | ||||
|             continue | ||||
|             return | ||||
|  | ||||
|         spec = module_info.module_finder.find_spec(name, None) | ||||
|         if spec.name in plugins: | ||||
|             continue | ||||
|             return | ||||
|         elif spec.name in sys.modules: | ||||
|             logger.warning( | ||||
|                 f"Module {spec.name} has been loaded by other plugin! Ignored") | ||||
|             continue | ||||
|             return | ||||
|  | ||||
|         try: | ||||
|             module = _load(spec) | ||||
|  | ||||
|             for m in _tmp_matchers: | ||||
|             for m in _tmp_matchers.get(): | ||||
|                 m.module = name | ||||
|             plugin = Plugin(name, module, _tmp_matchers.copy()) | ||||
|             plugin = Plugin(name, module, _tmp_matchers.get(), _export.get()) | ||||
|             plugins[name] = plugin | ||||
|             loaded_plugins.add(plugin) | ||||
|             logger.opt(colors=True).info(f'Succeeded to import "<y>{name}</y>"') | ||||
|             return plugin | ||||
|         except Exception as e: | ||||
|             logger.opt(colors=True, exception=e).error( | ||||
|                 f'<r><bg #f8bbd0>Failed to import "{name}"</bg #f8bbd0></r>') | ||||
|             return None | ||||
|  | ||||
|     loaded_plugins = set() | ||||
|     for module_info in pkgutil.iter_modules(plugin_dir): | ||||
|         context: Context = copy_context() | ||||
|         result = context.run(_load_plugin, module_info) | ||||
|         if result: | ||||
|             loaded_plugins.add(result) | ||||
|     return loaded_plugins | ||||
|  | ||||
|  | ||||
| @@ -479,3 +536,12 @@ def get_loaded_plugins() -> Set[Plugin]: | ||||
|       - ``Set[Plugin]`` | ||||
|     """ | ||||
|     return set(plugins.values()) | ||||
|  | ||||
|  | ||||
| def export() -> Export: | ||||
|     return _export.get() | ||||
|  | ||||
|  | ||||
| def require(name: str) -> Optional[Export]: | ||||
|     plugin = get_plugin(name) | ||||
|     return plugin.export if plugin else None | ||||
|   | ||||
		Reference in New Issue
	
	Block a user