mirror of
				https://github.com/nonebot/nonebot2.git
				synced 2025-10-26 20:46:39 +00:00 
			
		
		
		
	⚗️ add export require option
This commit is contained in:
		| @@ -240,4 +240,4 @@ async def _start_scheduler(): | |||||||
| from nonebot.plugin import on_message, on_notice, on_request, on_metaevent, CommandGroup | from nonebot.plugin import on_message, on_notice, on_request, on_metaevent, CommandGroup | ||||||
| from nonebot.plugin import on_startswith, on_endswith, on_keyword, on_command, on_regex | from nonebot.plugin import on_startswith, on_endswith, on_keyword, on_command, on_regex | ||||||
| from nonebot.plugin import load_plugin, load_plugins, load_builtin_plugins | from nonebot.plugin import load_plugin, load_plugins, load_builtin_plugins | ||||||
| from nonebot.plugin import get_plugin, get_loaded_plugins | from nonebot.plugin import export, require, get_plugin, get_loaded_plugins | ||||||
|   | |||||||
| @@ -11,6 +11,7 @@ import pkgutil | |||||||
| import importlib | import importlib | ||||||
| from dataclasses import dataclass | from dataclasses import dataclass | ||||||
| from importlib._bootstrap import _load | from importlib._bootstrap import _load | ||||||
|  | from contextvars import Context, ContextVar, copy_context | ||||||
|  |  | ||||||
| from nonebot.log import logger | from nonebot.log import logger | ||||||
| from nonebot.matcher import Matcher | from nonebot.matcher import Matcher | ||||||
| @@ -25,7 +26,45 @@ plugins: Dict[str, "Plugin"] = {} | |||||||
| :说明: 已加载的插件 | :说明: 已加载的插件 | ||||||
| """ | """ | ||||||
|  |  | ||||||
| _tmp_matchers: Set[Type[Matcher]] = set() | _tmp_matchers: ContextVar[Set[Type[Matcher]]] = ContextVar("_tmp_matchers") | ||||||
|  | _export: ContextVar["Export"] = ContextVar("_export") | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class Export(dict): | ||||||
|  |     """ | ||||||
|  |     :说明: | ||||||
|  |       插件导出内容以使得其他插件可以获得。 | ||||||
|  |     :示例: | ||||||
|  |  | ||||||
|  |     .. code-block:: python | ||||||
|  |  | ||||||
|  |         nonebot.export().default = "bar" | ||||||
|  |  | ||||||
|  |         @nonebot.export() | ||||||
|  |         def some_function(): | ||||||
|  |             pass | ||||||
|  |  | ||||||
|  |         @nonebot.export().sub | ||||||
|  |         def something_else(): | ||||||
|  |             pass | ||||||
|  |     """ | ||||||
|  |  | ||||||
|  |     def __call__(self, func, **kwargs): | ||||||
|  |         self[func.__name__] = func | ||||||
|  |         self.update(kwargs) | ||||||
|  |         return func | ||||||
|  |  | ||||||
|  |     def __setitem__(self, key, value): | ||||||
|  |         super().__setitem__(key, | ||||||
|  |                             Export(value) if isinstance(value, dict) else value) | ||||||
|  |  | ||||||
|  |     def __setattr__(self, name, value): | ||||||
|  |         self[name] = Export(value) if isinstance(value, dict) else value | ||||||
|  |  | ||||||
|  |     def __getattr__(self, name): | ||||||
|  |         if name not in self: | ||||||
|  |             self[name] = Export() | ||||||
|  |         return self[name] | ||||||
|  |  | ||||||
|  |  | ||||||
| @dataclass(eq=False) | @dataclass(eq=False) | ||||||
| @@ -46,6 +85,7 @@ class Plugin(object): | |||||||
|     - **类型**: ``Set[Type[Matcher]]`` |     - **类型**: ``Set[Type[Matcher]]`` | ||||||
|     - **说明**: 插件内定义的 ``Matcher`` |     - **说明**: 插件内定义的 ``Matcher`` | ||||||
|     """ |     """ | ||||||
|  |     export: Export | ||||||
|  |  | ||||||
|  |  | ||||||
| def on(type: str = "", | def on(type: str = "", | ||||||
| @@ -80,7 +120,7 @@ def on(type: str = "", | |||||||
|                           block=block, |                           block=block, | ||||||
|                           handlers=handlers, |                           handlers=handlers, | ||||||
|                           default_state=state) |                           default_state=state) | ||||||
|     _tmp_matchers.add(matcher) |     _tmp_matchers.get().add(matcher) | ||||||
|     return matcher |     return matcher | ||||||
|  |  | ||||||
|  |  | ||||||
| @@ -112,7 +152,7 @@ def on_metaevent(rule: Optional[Union[Rule, RuleChecker]] = None, | |||||||
|                           block=block, |                           block=block, | ||||||
|                           handlers=handlers, |                           handlers=handlers, | ||||||
|                           default_state=state) |                           default_state=state) | ||||||
|     _tmp_matchers.add(matcher) |     _tmp_matchers.get().add(matcher) | ||||||
|     return matcher |     return matcher | ||||||
|  |  | ||||||
|  |  | ||||||
| @@ -146,7 +186,7 @@ def on_message(rule: Optional[Union[Rule, RuleChecker]] = None, | |||||||
|                           block=block, |                           block=block, | ||||||
|                           handlers=handlers, |                           handlers=handlers, | ||||||
|                           default_state=state) |                           default_state=state) | ||||||
|     _tmp_matchers.add(matcher) |     _tmp_matchers.get().add(matcher) | ||||||
|     return matcher |     return matcher | ||||||
|  |  | ||||||
|  |  | ||||||
| @@ -178,7 +218,7 @@ def on_notice(rule: Optional[Union[Rule, RuleChecker]] = None, | |||||||
|                           block=block, |                           block=block, | ||||||
|                           handlers=handlers, |                           handlers=handlers, | ||||||
|                           default_state=state) |                           default_state=state) | ||||||
|     _tmp_matchers.add(matcher) |     _tmp_matchers.get().add(matcher) | ||||||
|     return matcher |     return matcher | ||||||
|  |  | ||||||
|  |  | ||||||
| @@ -210,7 +250,7 @@ def on_request(rule: Optional[Union[Rule, RuleChecker]] = None, | |||||||
|                           block=block, |                           block=block, | ||||||
|                           handlers=handlers, |                           handlers=handlers, | ||||||
|                           default_state=state) |                           default_state=state) | ||||||
|     _tmp_matchers.add(matcher) |     _tmp_matchers.get().add(matcher) | ||||||
|     return matcher |     return matcher | ||||||
|  |  | ||||||
|  |  | ||||||
| @@ -387,8 +427,11 @@ def load_plugin(module_path: str) -> Optional[Plugin]: | |||||||
|     :返回: |     :返回: | ||||||
|       - ``Optional[Plugin]`` |       - ``Optional[Plugin]`` | ||||||
|     """ |     """ | ||||||
|  |  | ||||||
|  |     def _load_plugin(module_path: str) -> Optional[Plugin]: | ||||||
|         try: |         try: | ||||||
|         _tmp_matchers.clear() |             _tmp_matchers.set(set()) | ||||||
|  |             _export.set(Export()) | ||||||
|             if module_path in plugins: |             if module_path in plugins: | ||||||
|                 return plugins[module_path] |                 return plugins[module_path] | ||||||
|             elif module_path in sys.modules: |             elif module_path in sys.modules: | ||||||
| @@ -397,18 +440,23 @@ def load_plugin(module_path: str) -> Optional[Plugin]: | |||||||
|                 ) |                 ) | ||||||
|                 return |                 return | ||||||
|             module = importlib.import_module(module_path) |             module = importlib.import_module(module_path) | ||||||
|         for m in _tmp_matchers: |             for m in _tmp_matchers.get(): | ||||||
|                 m.module = module_path |                 m.module = module_path | ||||||
|         plugin = Plugin(module_path, module, _tmp_matchers.copy()) |             plugin = Plugin(module_path, module, _tmp_matchers.get(), | ||||||
|  |                             _export.get()) | ||||||
|             plugins[module_path] = plugin |             plugins[module_path] = plugin | ||||||
|             logger.opt( |             logger.opt( | ||||||
|                 colors=True).info(f'Succeeded to import "<y>{module_path}</y>"') |                 colors=True).info(f'Succeeded to import "<y>{module_path}</y>"') | ||||||
|             return plugin |             return plugin | ||||||
|         except Exception as e: |         except Exception as e: | ||||||
|             logger.opt(colors=True, exception=e).error( |             logger.opt(colors=True, exception=e).error( | ||||||
|             f'<r><bg #f8bbd0>Failed to import "{module_path}"</bg #f8bbd0></r>') |                 f'<r><bg #f8bbd0>Failed to import "{module_path}"</bg #f8bbd0></r>' | ||||||
|  |             ) | ||||||
|             return None |             return None | ||||||
|  |  | ||||||
|  |     context: Context = copy_context() | ||||||
|  |     return context.run(_load_plugin, module_path) | ||||||
|  |  | ||||||
|  |  | ||||||
| def load_plugins(*plugin_dir: str) -> Set[Plugin]: | def load_plugins(*plugin_dir: str) -> Set[Plugin]: | ||||||
|     """ |     """ | ||||||
| @@ -419,33 +467,42 @@ def load_plugins(*plugin_dir: str) -> Set[Plugin]: | |||||||
|     :返回: |     :返回: | ||||||
|       - ``Set[Plugin]`` |       - ``Set[Plugin]`` | ||||||
|     """ |     """ | ||||||
|     loaded_plugins = set() |  | ||||||
|     for module_info in pkgutil.iter_modules(plugin_dir): |     def _load_plugin(module_info) -> Optional[Plugin]: | ||||||
|         _tmp_matchers.clear() |         _tmp_matchers.set(set()) | ||||||
|  |         _export.set(Export()) | ||||||
|         name = module_info.name |         name = module_info.name | ||||||
|         if name.startswith("_"): |         if name.startswith("_"): | ||||||
|             continue |             return | ||||||
|  |  | ||||||
|         spec = module_info.module_finder.find_spec(name, None) |         spec = module_info.module_finder.find_spec(name, None) | ||||||
|         if spec.name in plugins: |         if spec.name in plugins: | ||||||
|             continue |             return | ||||||
|         elif spec.name in sys.modules: |         elif spec.name in sys.modules: | ||||||
|             logger.warning( |             logger.warning( | ||||||
|                 f"Module {spec.name} has been loaded by other plugin! Ignored") |                 f"Module {spec.name} has been loaded by other plugin! Ignored") | ||||||
|             continue |             return | ||||||
|  |  | ||||||
|         try: |         try: | ||||||
|             module = _load(spec) |             module = _load(spec) | ||||||
|  |  | ||||||
|             for m in _tmp_matchers: |             for m in _tmp_matchers.get(): | ||||||
|                 m.module = name |                 m.module = name | ||||||
|             plugin = Plugin(name, module, _tmp_matchers.copy()) |             plugin = Plugin(name, module, _tmp_matchers.get(), _export.get()) | ||||||
|             plugins[name] = plugin |             plugins[name] = plugin | ||||||
|             loaded_plugins.add(plugin) |  | ||||||
|             logger.opt(colors=True).info(f'Succeeded to import "<y>{name}</y>"') |             logger.opt(colors=True).info(f'Succeeded to import "<y>{name}</y>"') | ||||||
|  |             return plugin | ||||||
|         except Exception as e: |         except Exception as e: | ||||||
|             logger.opt(colors=True, exception=e).error( |             logger.opt(colors=True, exception=e).error( | ||||||
|                 f'<r><bg #f8bbd0>Failed to import "{name}"</bg #f8bbd0></r>') |                 f'<r><bg #f8bbd0>Failed to import "{name}"</bg #f8bbd0></r>') | ||||||
|  |             return None | ||||||
|  |  | ||||||
|  |     loaded_plugins = set() | ||||||
|  |     for module_info in pkgutil.iter_modules(plugin_dir): | ||||||
|  |         context: Context = copy_context() | ||||||
|  |         result = context.run(_load_plugin, module_info) | ||||||
|  |         if result: | ||||||
|  |             loaded_plugins.add(result) | ||||||
|     return loaded_plugins |     return loaded_plugins | ||||||
|  |  | ||||||
|  |  | ||||||
| @@ -479,3 +536,12 @@ def get_loaded_plugins() -> Set[Plugin]: | |||||||
|       - ``Set[Plugin]`` |       - ``Set[Plugin]`` | ||||||
|     """ |     """ | ||||||
|     return set(plugins.values()) |     return set(plugins.values()) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def export() -> Export: | ||||||
|  |     return _export.get() | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def require(name: str) -> Optional[Export]: | ||||||
|  |     plugin = get_plugin(name) | ||||||
|  |     return plugin.export if plugin else None | ||||||
|   | |||||||
| @@ -1,17 +1,32 @@ | |||||||
| import re | import re | ||||||
|  | from contextvars import ContextVar | ||||||
|  |  | ||||||
| from nonebot.typing import Rule, Matcher, Handler, Permission, RuleChecker | from nonebot.typing import Rule, Matcher, Handler, Permission, RuleChecker | ||||||
| from nonebot.typing import Set, List, Dict, Type, Tuple, Union, Optional, ModuleType | from nonebot.typing import Set, List, Dict, Type, Tuple, Union, Optional, ModuleType | ||||||
|  |  | ||||||
| plugins: Dict[str, "Plugin"] = ... | plugins: Dict[str, "Plugin"] = ... | ||||||
|  |  | ||||||
| _tmp_matchers: Set[Type[Matcher]] = ... | _tmp_matchers: ContextVar[Set[Type[Matcher]]] = ... | ||||||
|  | _export: ContextVar["Export"] = ... | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class Export(dict): | ||||||
|  |  | ||||||
|  |     def __call__(self, func, **kwargs): | ||||||
|  |         ... | ||||||
|  |  | ||||||
|  |     def __setattr__(self, name, value): | ||||||
|  |         ... | ||||||
|  |  | ||||||
|  |     def __getattr__(self, name): | ||||||
|  |         ... | ||||||
|  |  | ||||||
|  |  | ||||||
| class Plugin(object): | class Plugin(object): | ||||||
|     name: str |     name: str | ||||||
|     module: ModuleType |     module: ModuleType | ||||||
|     matcher: Set[Type[Matcher]] |     matcher: Set[Type[Matcher]] | ||||||
|  |     export: Export | ||||||
|  |  | ||||||
|  |  | ||||||
| def on(type: str = ..., | def on(type: str = ..., | ||||||
| @@ -149,6 +164,14 @@ def get_loaded_plugins() -> Set[Plugin]: | |||||||
|     ... |     ... | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def export() -> Export: | ||||||
|  |     ... | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def require(name: str) -> Export: | ||||||
|  |     ... | ||||||
|  |  | ||||||
|  |  | ||||||
| class CommandGroup: | class CommandGroup: | ||||||
|  |  | ||||||
|     def __init__(self, |     def __init__(self, | ||||||
|   | |||||||
| @@ -9,6 +9,8 @@ sidebar: auto | |||||||
| - 修复 cqhttp 检查 to me 时出现 IndexError | - 修复 cqhttp 检查 to me 时出现 IndexError | ||||||
| - 修复已失效的事件响应器仍会运行一次的 bug | - 修复已失效的事件响应器仍会运行一次的 bug | ||||||
| - 修改 cqhttp 检查 reply 时未去除后续 at 以及空格 | - 修改 cqhttp 检查 reply 时未去除后续 at 以及空格 | ||||||
|  | - 添加 get_plugin 获取插件函数 | ||||||
|  | - 添加插件 export, require 方法 | ||||||
|  |  | ||||||
| ## v2.0.0a6 | ## v2.0.0a6 | ||||||
|  |  | ||||||
|   | |||||||
| @@ -22,6 +22,8 @@ nonebot.load_builtin_plugins() | |||||||
| # load local plugins | # load local plugins | ||||||
| nonebot.load_plugins("test_plugins") | nonebot.load_plugins("test_plugins") | ||||||
|  |  | ||||||
|  | print(nonebot.require("test_export")) | ||||||
|  |  | ||||||
| # modify some config / config depends on loaded configs | # modify some config / config depends on loaded configs | ||||||
| config = nonebot.get_driver().config | config = nonebot.get_driver().config | ||||||
| config.custom_config3 = config.custom_config1 | config.custom_config3 = config.custom_config1 | ||||||
|   | |||||||
							
								
								
									
										15
									
								
								tests/test_plugins/test_export.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										15
									
								
								tests/test_plugins/test_export.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,15 @@ | |||||||
|  | import nonebot | ||||||
|  |  | ||||||
|  | export = nonebot.export() | ||||||
|  | export.foo = "bar" | ||||||
|  | export["bar"] = "foo" | ||||||
|  |  | ||||||
|  |  | ||||||
|  | @export | ||||||
|  | def a(): | ||||||
|  |     pass | ||||||
|  |  | ||||||
|  |  | ||||||
|  | @export.sub | ||||||
|  | def b(): | ||||||
|  |     pass | ||||||
		Reference in New Issue
	
	Block a user