🐛 fix error matcher module when import

This commit is contained in:
yanyongyu
2021-03-22 01:15:15 +08:00
parent d738f8674d
commit 6371cd6bfe
2 changed files with 30 additions and 13 deletions

View File

@ -5,10 +5,14 @@ import importlib
from hashlib import md5
from types import ModuleType
from collections import Counter
from contextvars import ContextVar
from importlib.abc import MetaPathFinder
from importlib.machinery import PathFinder, SourceFileLoader
from importlib.machinery import PathFinder, FrozenImporter, SourceFileLoader
from typing import Set, List, Optional, Iterable
_current_plugin: ContextVar[Optional[str]] = ContextVar("_current_plugin",
default=None)
_internal_space = ModuleType(__name__ + "._internal")
_internal_space.__path__ = [] # type: ignore
sys.modules[_internal_space.__name__] = _internal_space
@ -138,7 +142,8 @@ class PluginManager:
def load_plugin(self, name) -> ModuleType:
if name in self.plugins:
return importlib.import_module(name)
with self:
return importlib.import_module(name)
if "." in name:
raise ValueError("Plugin name cannot contain '.'")
@ -150,14 +155,15 @@ class PluginManager:
return [self.load_plugin(name) for name in self.list_plugins()]
def _rewrite_module_name(self, module_name) -> Optional[str]:
if module_name == self.namespace:
return self.internal_module.__name__
elif module_name.startswith(self.namespace + "."):
prefix = f"{self.internal_module.__name__}."
if module_name.startswith(self.namespace + "."):
path = module_name.split(".")
length = self.namespace.count(".") + 1
return f"{self.internal_module.__name__}.{'.'.join(path[length:])}"
return f"{prefix}{'.'.join(path[length:])}"
elif module_name in self.search_plugins():
return f"{self.internal_module.__name__}.{module_name}"
return f"{prefix}{module_name}"
elif module_name in self.plugins or module_name.startswith(prefix):
return module_name
return None
@ -170,9 +176,8 @@ class PluginFinder(MetaPathFinder):
manager = _manager_stack[index]
newname = manager._rewrite_module_name(fullname)
if newname:
spec = PathFinder.find_spec(newname,
list(manager.search_path),
target)
spec = PathFinder.find_spec(
newname, [*manager.search_path, *(path or [])], target)
if spec:
spec.loader = PluginLoader(manager, newname,
spec.origin)
@ -186,12 +191,17 @@ class PluginLoader(SourceFileLoader):
def __init__(self, manager: PluginManager, fullname: str, path) -> None:
self.manager = manager
self.loaded = False
self._context_token = None
super().__init__(fullname, path)
def create_module(self, spec) -> Optional[ModuleType]:
if self.name in sys.modules:
self.loaded = True
return sys.modules[self.name]
prefix = self.manager.internal_module.__name__
plugin_name = self.name[len(prefix):] if self.name.startswith(
prefix) else self.name
self._context_token = _current_plugin.set(plugin_name.lstrip("."))
# return None to use default module creation
return super().create_module(spec)
@ -200,6 +210,8 @@ class PluginLoader(SourceFileLoader):
return
setattr(module, "__manager__", self.manager)
super().exec_module(module)
if self._context_token:
_current_plugin.reset(self._context_token)
return