mirror of
https://github.com/LiteyukiStudio/nonebot-plugin-marshoai.git
synced 2025-08-01 07:09:50 +00:00
✨ 重构Caller类,移除泛型参数;添加函数签名复制装饰器
This commit is contained in:
@ -1,36 +1,34 @@
|
||||
from typing import Generic, TypeVar
|
||||
import inspect
|
||||
from typing import Any
|
||||
|
||||
from nonebot import logger
|
||||
from nonebot.adapters import Event
|
||||
|
||||
from ..typing import FUNCTION_CALL_FUNC
|
||||
from .params import P
|
||||
from ..typing import ASYNC_FUNCTION_CALL_FUNC, F
|
||||
from .utils import async_wrap, is_coroutine_callable
|
||||
|
||||
F = TypeVar("F", bound=FUNCTION_CALL_FUNC)
|
||||
_caller_data: dict[str, "Caller"] = {}
|
||||
|
||||
|
||||
class Caller(Generic[P]):
|
||||
class Caller:
|
||||
def __init__(self, name: str | None = None, description: str | None = None):
|
||||
self._name = name
|
||||
self._description = description
|
||||
self._parameters: dict[str, P] = {}
|
||||
self.func: FUNCTION_CALL_FUNC | None = None
|
||||
self.func: ASYNC_FUNCTION_CALL_FUNC | None = None
|
||||
self._parameters: dict[str, Any] = {}
|
||||
"""依赖注入的参数"""
|
||||
self.event: Event | None = None
|
||||
|
||||
def params(self, **kwargs: P) -> "Caller":
|
||||
"""设置多个函数参数
|
||||
Args:
|
||||
**kwargs: 参数字典
|
||||
Returns:
|
||||
Caller: Caller对象
|
||||
"""
|
||||
def params(self, **kwargs: Any) -> "Caller":
|
||||
self._parameters.update(kwargs)
|
||||
return self
|
||||
|
||||
def param(self, name: str, param: P) -> "Caller":
|
||||
def param(self, name: str, param: Any) -> "Caller":
|
||||
"""设置一个函数参数
|
||||
|
||||
Args:
|
||||
name (str): 参数名
|
||||
param (P): 参数对象
|
||||
param (Any): 参数对象
|
||||
|
||||
Returns:
|
||||
Caller: Caller对象
|
||||
@ -51,14 +49,6 @@ class Caller(Generic[P]):
|
||||
return self
|
||||
|
||||
def description(self, description: str) -> "Caller":
|
||||
"""设置函数描述
|
||||
|
||||
Args:
|
||||
description (str): 函数描述
|
||||
|
||||
Returns:
|
||||
Caller: Caller对象
|
||||
"""
|
||||
self._description = description
|
||||
return self
|
||||
|
||||
@ -71,12 +61,78 @@ class Caller(Generic[P]):
|
||||
Returns:
|
||||
F: 函数对象
|
||||
"""
|
||||
global _caller_data
|
||||
if self._name is None:
|
||||
if module := inspect.getmodule(func):
|
||||
module_name = module.__name__.split(".")[-1]
|
||||
else:
|
||||
module_name = "global"
|
||||
self._name = f"{module_name}-{func.__name__}"
|
||||
_caller_data[self._name] = self
|
||||
|
||||
if is_coroutine_callable(func):
|
||||
self.func = func # type: ignore
|
||||
else:
|
||||
self.func = async_wrap(func) # type: ignore
|
||||
|
||||
if module := inspect.getmodule(func):
|
||||
module_name = module.__name__ + "."
|
||||
else:
|
||||
module_name = ""
|
||||
logger.opt(colors=True).info(
|
||||
f"<y>加载函数 {func.__name__} {self._description}</y>"
|
||||
f"<y>加载函数 {module_name}{func.__name__}: {self._description}</y>"
|
||||
)
|
||||
self.func = func
|
||||
|
||||
return func
|
||||
|
||||
def data(self) -> dict[str, Any]:
|
||||
"""返回函数的json数据
|
||||
|
||||
Returns:
|
||||
dict[str, Any]: 函数的json数据
|
||||
"""
|
||||
return {
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": self._name,
|
||||
"description": self._description,
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
key: value.data() for key, value in self._parameters.items()
|
||||
},
|
||||
},
|
||||
"required": [
|
||||
key
|
||||
for key, value in self._parameters.items()
|
||||
if value.default is None
|
||||
],
|
||||
},
|
||||
}
|
||||
|
||||
def set_event(self, event: Event):
|
||||
self.event = event
|
||||
|
||||
async def call(self, *args: Any, **kwargs: Any) -> Any:
|
||||
"""调用函数
|
||||
|
||||
Returns:
|
||||
Any: 函数返回值
|
||||
"""
|
||||
if self.func is None:
|
||||
raise ValueError("未注册函数对象")
|
||||
sig = inspect.signature(self.func)
|
||||
for name, param in sig.parameters.items():
|
||||
if issubclass(param.annotation, Event) or isinstance(
|
||||
param.annotation, Event
|
||||
):
|
||||
kwargs[name] = self.event
|
||||
if issubclass(param.annotation, Caller) or isinstance(
|
||||
param.annotation, Caller
|
||||
):
|
||||
kwargs[name] = self
|
||||
return await self.func(*args, **kwargs)
|
||||
|
||||
|
||||
def on_function_call(name: str | None = None, description: str | None = None) -> Caller:
|
||||
"""返回一个Caller类,可用于装饰一个函数,使其注册为一个可被AI调用的function call函数
|
||||
@ -87,5 +143,14 @@ def on_function_call(name: str | None = None, description: str | None = None) ->
|
||||
Returns:
|
||||
Caller: Caller对象
|
||||
"""
|
||||
caller = Caller(name=name, description=description)
|
||||
return caller
|
||||
|
||||
return Caller(name=name, description=description)
|
||||
|
||||
def get_function_calls() -> dict[str, Caller]:
|
||||
"""获取所有已注册的function call函数
|
||||
|
||||
Returns:
|
||||
dict[str, Caller]: 所有已注册的function call函数
|
||||
"""
|
||||
return _caller_data
|
||||
|
Reference in New Issue
Block a user