Compare commits

..

17 Commits

Author SHA1 Message Date
Akarin~
872be20950 Merge branch 'main' into mod/opti 2025-02-26 00:41:34 +08:00
f5ea844156 修复依赖注入问题? 2025-02-26 00:24:44 +08:00
Akarin~
aa53643aae 更好的缓存,扬掉global,重构代码,整理聊天逻辑 (#16)
* 实现缓存装饰器,优化数据获取和存储逻辑

* 重构代码,准备将聊天请求逻辑移入MarshoHandler

* 记录点(

* unfinished

* 🎨 重写基本完毕

* 移除未使用import,添加漏掉的换行
2025-02-24 01:19:26 +08:00
4fbe6c6366 移除未使用import,添加漏掉的换行 2025-02-24 00:09:39 +08:00
5b315c46b1 🎨 重写基本完毕 2025-02-23 14:50:35 +08:00
d9f22fa0f7 unfinished 2025-02-23 11:37:29 +08:00
091e88fe81 记录点( 2025-02-23 00:23:18 +08:00
Akarin~
5efd753557 Merge branch 'main' into mod/opti 2025-02-22 23:19:54 +08:00
3436390f4b 💫 添加starify 2025-02-22 23:18:22 +08:00
17551885f5 重构代码,准备将聊天请求逻辑移入MarshoHandler 2025-02-22 20:39:03 +08:00
aaa4056482 实现缓存装饰器,优化数据获取和存储逻辑 2025-02-22 13:38:16 +08:00
e1bc81c9e1 pre implement cache 2025-02-22 13:06:06 +08:00
5eb3c66232 Merge branch 'main' of https://github.com/LiteyukiStudio/nonebot-plugin-marshoai 2025-02-17 01:36:23 +08:00
a5e72c6946 修复 lint,忽略F405 2025-02-17 01:35:36 +08:00
金羿ELS
2be57309bd 😋ヾ(≧▽≦*)o让自述文件更美 (#14)
* ヾ(≧▽≦*)o让README更美。

* 真正的美

* 水提交

---------

Co-authored-by: Akarin~ <60691961+Asankilp@users.noreply.github.com>
2025-02-17 01:13:52 +08:00
0b6ac9f73e 修复部分 lint 2025-02-17 01:05:19 +08:00
Akarin~
0e72880167 yaml配置系统重构 (#13)
* 重构模型参数配置,合并为marshoai_model_args字典

* 重构配置管理,移除模板配置文件并实现从ConfigModel读取默认配置并写入

* 修复类型错误
2025-02-15 20:36:10 +08:00
12 changed files with 386 additions and 256 deletions

View File

@@ -10,7 +10,7 @@
_✨ 使用 OpenAI 标准格式 API 的聊天机器人插件 ✨_ _✨ 使用 OpenAI 标准格式 API 的聊天机器人插件 ✨_
[![QQ群](https://img.shields.io/badge/QQ群-1029557452-blue.svg?logo=QQ)](https://qm.qq.com/q/a13iwP5kAw) [![QQ群](https://img.shields.io/badge/QQ群-1029557452-blue.svg?logo=QQ&style=flat-square)](https://qm.qq.com/q/a13iwP5kAw)
[![NoneBot Registry](https://img.shields.io/endpoint?url=https%3A%2F%2Fnbbdg.lgc2333.top%2Fplugin%2Fnonebot-plugin-marshoai&style=flat-square)](https://registry.nonebot.dev/plugin/nonebot-plugin-marshoai:nonebot_plugin_marshoai) [![NoneBot Registry](https://img.shields.io/endpoint?url=https%3A%2F%2Fnbbdg.lgc2333.top%2Fplugin%2Fnonebot-plugin-marshoai&style=flat-square)](https://registry.nonebot.dev/plugin/nonebot-plugin-marshoai:nonebot_plugin_marshoai)
<a href="https://registry.nonebot.dev/plugin/nonebot-plugin-marshoai:nonebot_plugin_marshoai"> <a href="https://registry.nonebot.dev/plugin/nonebot-plugin-marshoai:nonebot_plugin_marshoai">
<img src="https://img.shields.io/endpoint?url=https%3A%2F%2Fnbbdg.lgc2333.top%2Fplugin-adapters%2Fnonebot-plugin-marshoai&style=flat-square" alt="Supported Adapters"> <img src="https://img.shields.io/endpoint?url=https%3A%2F%2Fnbbdg.lgc2333.top%2Fplugin-adapters%2Fnonebot-plugin-marshoai&style=flat-square" alt="Supported Adapters">
@@ -22,20 +22,21 @@ _✨ 使用 OpenAI 标准格式 API 的聊天机器人插件 ✨_
<img src="https://img.shields.io/badge/Code%20Style-Black-121110.svg?style=flat-square" alt="codestyle"> <img src="https://img.shields.io/badge/Code%20Style-Black-121110.svg?style=flat-square" alt="codestyle">
</div> </div>
<img width="100%" src="https://starify.komoridevs.icu/api/starify?owner=LiteyukiStudio&repo=nonebot-plugin-marshoai" alt="starify" />
## 📖 介绍 ## 📖 介绍
通过调用 OpenAI 标准格式 API(例如 GitHub Models API) 来实现聊天的插件。 通过调用 OpenAI 标准格式 API例如 GitHub Models API来实现聊天的插件。
插件内置了猫娘小棉(Marsho)的人物设定,可以进行可爱的聊天! 插件内置了猫娘小棉Marsho,マルショ)的人物设定,可以进行可爱的聊天!
_谁不喜欢回复消息快又可爱的猫娘呢?_ _谁不喜欢回复消息快又可爱的猫娘呢?_
**对 OneBot 以外的适配器与非 GitHub Models API 的支持未经过完全验证。** **对 OneBot 以外的适配器与非 GitHub Models API 的支持未完全经过验证。**
[Melobot 实现](https://github.com/LiteyukiStudio/marshoai-melo) [Melobot 实现](https://github.com/LiteyukiStudio/marshoai-melo)
## 🐱 设定 ## 🐱 设定
#### 基本信息 #### 基本信息
- 名字:小棉(Marsho) - 名字:小棉Marsho,マルショ)
- 生日9 月 6 日 - 生日9 月 6 日
#### 喜好 #### 喜好
@@ -58,7 +59,7 @@ _谁不喜欢回复消息快又可爱的猫娘呢_
- [nonebot-plugin-latex](https://github.com/EillesWan/nonebot-plugin-latex) - [nonebot-plugin-latex](https://github.com/EillesWan/nonebot-plugin-latex)
- [nonebot-plugin-deepseek](https://github.com/KomoriDev/nonebot-plugin-deepseek) - [nonebot-plugin-deepseek](https://github.com/KomoriDev/nonebot-plugin-deepseek)
"Marsho" logo 由 [@Asankilp](https://github.com/Asankilp)绘制,基于 [CC BY-NC-SA 4.0](http://creativecommons.org/licenses/by-nc-sa/4.0/) 许可下提供。 "Marsho" logo 由 [@Asankilp](https://github.com/Asankilp) 绘制,基于 [CC BY-NC-SA 4.0](http://creativecommons.org/licenses/by-nc-sa/4.0/) 许可下提供。
"nonebot-plugin-marshoai" 基于 [MIT](./LICENSE-MIT) 许可下提供。 "nonebot-plugin-marshoai" 基于 [MIT](./LICENSE-MIT) 许可下提供。
部分指定的代码基于 [Mulan PSL v2](./LICENSE-MULAN) 许可下提供。 部分指定的代码基于 [Mulan PSL v2](./LICENSE-MULAN) 许可下提供。

39
nonebot_plugin_marshoai/cache/decos.py vendored Normal file
View File

@@ -0,0 +1,39 @@
from ..models import Cache
cache = Cache()
def from_cache(key):
    """
    Decorator factory: return the cached value for *key* when present,
    otherwise await the wrapped coroutine, store its result under *key*,
    and return it.

    Note: the cache key is fixed per decoration, not derived from call
    arguments — all calls to the decorated coroutine share one entry.
    """
    import functools  # local import keeps this module's top-level imports untouched

    def decorator(func):
        @functools.wraps(func)  # preserve __name__/__doc__ for logging and debugging
        async def wrapper(*args, **kwargs):
            cached = cache.get(key)
            # Compare against None explicitly so cached falsy values
            # (e.g. an empty nickname dict {}) still count as cache hits
            # instead of re-running the loader on every call.
            if cached is not None:
                return cached
            result = await func(*args, **kwargs)
            cache.set(key, result)
            return result

        return wrapper

    return decorator
def update_to_cache(key):
    """
    Decorator factory: always await the wrapped coroutine and overwrite the
    cache entry for *key* with its result (write-through refresh).
    """
    import functools  # local import keeps this module's top-level imports untouched

    def decorator(func):
        @functools.wraps(func)  # preserve __name__/__doc__ for logging and debugging
        async def wrapper(*args, **kwargs):
            result = await func(*args, **kwargs)
            cache.set(key, result)
            return result

        return wrapper

    return decorator

View File

@@ -1,4 +1,3 @@
import shutil
from io import StringIO from io import StringIO
from pathlib import Path from pathlib import Path
@@ -81,15 +80,15 @@ destination_folder = Path("config/marshoai/")
destination_file = destination_folder / "config.yaml" destination_file = destination_folder / "config.yaml"
def dump_config_to_yaml(config: ConfigModel): def dump_config_to_yaml(cfg: ConfigModel):
return yaml_.dump(config.model_dump(), allow_unicode=True, default_flow_style=False) return yaml_.dump(cfg.model_dump(), allow_unicode=True, default_flow_style=False)
def write_default_config(destination_file): def write_default_config(dest_file):
""" """
写入默认配置 写入默认配置
""" """
with open(destination_file, "w", encoding="utf-8") as f: with open(dest_file, "w", encoding="utf-8") as f:
with StringIO(dump_config_to_yaml(ConfigModel())) as f2: with StringIO(dump_config_to_yaml(ConfigModel())) as f2:
f.write(f2.read()) f.write(f2.read())
@@ -110,17 +109,17 @@ def check_yaml_is_changed():
return True return True
def merge_configs(old_config, new_config): def merge_configs(existing_cfg, new_cfg):
""" """
合并配置文件 合并配置文件
""" """
for key, value in new_config.items(): for key, value in new_cfg.items():
if key in old_config: if key in existing_cfg:
continue continue
else: else:
logger.info(f"新增配置项: {key} = {value}") logger.info(f"新增配置项: {key} = {value}")
old_config[key] = value existing_cfg[key] = value
return old_config return existing_cfg
config: ConfigModel = get_plugin_config(ConfigModel) config: ConfigModel = get_plugin_config(ConfigModel)

View File

@@ -0,0 +1,242 @@
import json
from typing import Optional, Tuple, Union
from azure.ai.inference.models import (
CompletionsFinishReason,
ImageContentItem,
ImageUrl,
TextContentItem,
ToolMessage,
UserMessage,
)
from nonebot.adapters import Bot, Event
from nonebot.log import logger
from nonebot.matcher import (
Matcher,
current_bot,
current_event,
current_matcher,
)
from nonebot_plugin_alconna.uniseg import UniMessage, UniMsg
from openai import AsyncOpenAI
from openai.types.chat import ChatCompletion, ChatCompletionMessage
from .config import config
from .constants import SUPPORT_IMAGE_MODELS
from .instances import target_list
from .models import MarshoContext
from .plugin.func_call.caller import get_function_calls
from .plugin.func_call.models import SessionContext
from .util import (
extract_content_and_think,
get_image_b64,
get_nickname_by_user_id,
get_prompt,
make_chat_openai,
parse_richtext,
)
class MarshoHandler:
    """Drives one chat round against the OpenAI-compatible API for the
    current nonebot event: input normalization, request dispatch, tool
    (function) calling, and reply delivery.

    Instances are created per matcher run; bot/event/matcher are resolved
    from nonebot's current-context variables, so construction is only valid
    inside an event handler.
    """

    def __init__(
        self,
        client: AsyncOpenAI,
        context: MarshoContext,
    ):
        self.client = client
        self.context = context
        # Resolved from nonebot context vars — valid only inside a running handler.
        self.bot: Bot = current_bot.get()
        self.event: Event = current_event.get()
        # self.state: T_State = current_handler.get().state
        self.matcher: Matcher = current_matcher.get()
        self.message_id: str = UniMessage.get_message_id(self.event)
        self.target = UniMessage.get_target(self.event)

    async def process_user_input(
        self, user_input: UniMsg, model_name: str
    ) -> Union[str, list]:
        """Convert the incoming message into the API input format and append
        a nickname hint.

        Returns a list of content-item dicts for image-capable models, or a
        plain string otherwise.
        """
        is_support_image_model = (
            model_name.lower()
            in SUPPORT_IMAGE_MODELS + config.marshoai_additional_image_models
        )
        # list for vision models (mixed text/image items), str for text-only models
        usermsg = [] if is_support_image_model else ""
        user_nickname = await get_nickname_by_user_id(self.event.get_user_id())
        if user_nickname:
            nickname_prompt = f"\n此消息的说话者为: {user_nickname}"
        else:
            nickname_prompt = ""
        for i in user_input:  # type: ignore
            if i.type == "text":
                if is_support_image_model:
                    usermsg += [TextContentItem(text=i.data["text"] + nickname_prompt).as_dict()]  # type: ignore
                else:
                    usermsg += str(i.data["text"] + nickname_prompt)  # type: ignore
            elif i.type == "image":
                if is_support_image_model:
                    # Images are inlined as base64 data URLs — TODO confirm
                    # get_image_b64 returns a data-URL-compatible string.
                    usermsg.append(  # type: ignore
                        ImageContentItem(
                            image_url=ImageUrl(  # type: ignore
                                url=str(await get_image_b64(i.data["url"]))  # type: ignore
                            )  # type: ignore
                        ).as_dict()  # type: ignore
                    )  # type: ignore
                    logger.info(f"输入图片 {i.data['url']}")
                elif config.marshoai_enable_support_image_tip:
                    await UniMessage(
                        "*此模型不支持图片处理或管理员未启用此模型的图片支持。图片将被忽略。"
                    ).send()
        return usermsg  # type: ignore

    async def handle_single_chat(
        self,
        user_message: Union[str, list],
        model_name: str,
        tools_list: list,
        tool_message: Optional[list] = None,
    ) -> ChatCompletion:
        """Issue a single chat-completion request.

        The request message list is: system prompt + stored conversation
        context + the current user message, plus any tool-result messages
        carried over from a tool-call round.
        """
        context_msg = get_prompt(model_name) + (
            self.context.build(self.target.id, self.target.private)
        )
        response = await make_chat_openai(
            client=self.client,
            msg=context_msg + [UserMessage(content=user_message).as_dict()] + (tool_message if tool_message else []),  # type: ignore
            model_name=model_name,
            tools=tools_list if tools_list else None,
        )
        return response

    async def handle_function_call(
        self,
        completion: ChatCompletion,
        user_message: Union[str, list],
        model_name: str,
        tools_list: list,
    ):
        """Execute the tool calls requested in *completion*, then re-enter
        handle_common_chat with the tool results appended so the model can
        produce its final answer.
        """
        # Function call: the model needs extra information, so invoke the
        # requested function tools.
        tool_msg = []
        choice = completion.choices[0]
        # await UniMessage(str(response)).send()
        tool_calls = choice.message.tool_calls
        # try:
        #     if tool_calls[0]["function"]["name"].startswith("$"):
        #         choice.message.tool_calls[0][
        #             "type"
        #         ] = "builtin_function"  # temporary workaround for Moonshot AI builtin functions
        # except:
        #     pass
        tool_msg.append(choice.message)
        for tool_call in tool_calls:  # type: ignore
            try:
                function_args = json.loads(tool_call.function.arguments)
            except json.JSONDecodeError:
                # Some models emit single-quoted pseudo-JSON; retry after
                # a naive quote swap.
                function_args = json.loads(
                    tool_call.function.arguments.replace("'", '"')
                )
            # Drop the "placeholder" argument from args (presumably injected
            # for parameterless tool schemas — confirm against tool defs).
            if "placeholder" in function_args:
                del function_args["placeholder"]
            logger.info(
                f"调用函数 {tool_call.function.name.replace('-', '.')}\n参数:"
                + "\n".join([f"{k}={v}" for k, v in function_args.items()])
            )
            await UniMessage(
                f"调用函数 {tool_call.function.name.replace('-', '.')}\n参数:"
                + "\n".join([f"{k}={v}" for k, v in function_args.items()])
            ).send()
            if caller := get_function_calls().get(tool_call.function.name):
                logger.debug(f"调用插件函数 {caller.full_name}")
                # Permission check / rule check TODO
                # Dependency injection: the caller inspects the function's
                # parameters and annotations and injects Event-typed arguments.
                func_return = await caller.with_ctx(
                    SessionContext(
                        bot=self.bot,
                        event=self.event,
                        matcher=self.matcher,
                        state=None,
                    )
                ).call(**function_args)
            else:
                logger.error(f"未找到函数 {tool_call.function.name.replace('-', '.')}")
                func_return = f"未找到函数 {tool_call.function.name.replace('-', '.')}"
            tool_msg.append(
                ToolMessage(tool_call_id=tool_call.id, content=func_return).as_dict()  # type: ignore
            )
        # tool_msg[0]["tool_calls"][0]["type"] = "builtin_function"
        # await UniMessage(str(tool_msg)).send()
        return await self.handle_common_chat(
            user_message=user_message,
            model_name=model_name,
            tools_list=tools_list,
            tool_message=tool_msg,
        )

    async def handle_common_chat(
        self,
        user_message: Union[str, list],
        model_name: str,
        tools_list: list,
        stream: bool = False,
        tool_message: Optional[list] = None,
    ) -> Optional[Tuple[UserMessage, ChatCompletionMessage]]:
        """Handle an ordinary chat round.

        Returns the (user message, assistant message) pair for the caller to
        record into the conversation context, or None on filtered/unexpected
        completions. Tool-call completions recurse via handle_function_call.
        """
        # NOTE(review): target_list is imported from .instances and only
        # appended to below; `global` is a no-op here — confirm intent.
        global target_list
        if stream:
            raise NotImplementedError
        response = await self.handle_single_chat(
            user_message=user_message,
            model_name=model_name,
            tools_list=tools_list,
            tool_message=tool_message,
        )
        choice = response.choices[0]
        # Sprint(choice)
        # When tool_calls is non-empty, force finish_reason to "tool_calls"
        # (config-gated workaround for providers that leave it unset).
        if choice.message.tool_calls is not None and config.marshoai_fix_toolcalls:
            choice.finish_reason = "tool_calls"
        logger.info(f"完成原因:{choice.finish_reason}")
        if choice.finish_reason == CompletionsFinishReason.STOPPED:
            ##### DeepSeek-R1 compatibility section #####
            choice_msg_content, choice_msg_thinking, choice_msg_after = (
                extract_content_and_think(choice.message)
            )
            if choice_msg_thinking and config.marshoai_send_thinking:
                await UniMessage("思维链:\n" + choice_msg_thinking).send()
            ##### end of compatibility section #####
            if [self.target.id, self.target.private] not in target_list:
                target_list.append([self.target.id, self.target.private])
            # Chat succeeded — deliver the reply.
            if config.marshoai_enable_richtext_parse:
                await (await parse_richtext(str(choice_msg_content))).send(
                    reply_to=True
                )
            else:
                await UniMessage(str(choice_msg_content)).send(reply_to=True)
            return UserMessage(content=user_message), choice_msg_after
        elif choice.finish_reason == CompletionsFinishReason.CONTENT_FILTERED:
            # Chat failed — blocked by the content filter.
            await UniMessage("*已被内容过滤器过滤。请调整聊天内容后重试。").send(
                reply_to=True
            )
            return None
        elif choice.finish_reason == CompletionsFinishReason.TOOL_CALLS:
            return await self.handle_function_call(
                response, user_message, model_name, tools_list
            )
        else:
            await UniMessage(f"意外的完成原因:{choice.finish_reason}").send()
            return None

View File

@@ -6,7 +6,7 @@ import nonebot_plugin_localstore as store
from nonebot import logger from nonebot import logger
from .config import config from .config import config
from .instances import * from .instances import context, driver, target_list, tools
from .plugin import load_plugin, load_plugins from .plugin import load_plugin, load_plugins
from .util import get_backup_context, save_context_to_json from .util import get_backup_context, save_context_to_json

View File

@@ -1,6 +1,4 @@
# Marsho 的类实例以及全局变量 # Marsho 的类实例以及全局变量
from azure.ai.inference.aio import ChatCompletionsClient
from azure.core.credentials import AzureKeyCredential
from nonebot import get_driver from nonebot import get_driver
from openai import AsyncOpenAI from openai import AsyncOpenAI

View File

@@ -2,15 +2,10 @@ import contextlib
import traceback import traceback
from typing import Optional from typing import Optional
import openai
from arclet.alconna import Alconna, AllParam, Args from arclet.alconna import Alconna, AllParam, Args
from azure.ai.inference.models import ( from azure.ai.inference.models import (
AssistantMessage, AssistantMessage,
CompletionsFinishReason, CompletionsFinishReason,
ImageContentItem,
ImageUrl,
TextContentItem,
ToolMessage,
UserMessage, UserMessage,
) )
from nonebot import logger, on_command, on_message from nonebot import logger, on_command, on_message
@@ -18,15 +13,17 @@ from nonebot.adapters import Bot, Event, Message
from nonebot.matcher import Matcher from nonebot.matcher import Matcher
from nonebot.params import CommandArg from nonebot.params import CommandArg
from nonebot.permission import SUPERUSER from nonebot.permission import SUPERUSER
from nonebot.rule import Rule, to_me from nonebot.rule import to_me
from nonebot.typing import T_State from nonebot.typing import T_State
from nonebot_plugin_alconna import MsgTarget, UniMessage, UniMsg, on_alconna from nonebot_plugin_alconna import MsgTarget, UniMessage, UniMsg, on_alconna
from .config import config
from .constants import INTRODUCTION, SUPPORT_IMAGE_MODELS
from .handler import MarshoHandler
from .hooks import * from .hooks import *
from .instances import * from .instances import client, context, model_name, target_list, tools
from .metadata import metadata from .metadata import metadata
from .plugin.func_call.caller import get_function_calls from .plugin.func_call.caller import get_function_calls
from .plugin.func_call.models import SessionContext
from .util import * from .util import *
@@ -230,16 +227,16 @@ async def marsho(
# 发送说明 # 发送说明
# await UniMessage(metadata.usage + "\n当前使用的模型" + model_name).send() # await UniMessage(metadata.usage + "\n当前使用的模型" + model_name).send()
await marsho_cmd.finish(INTRODUCTION) await marsho_cmd.finish(INTRODUCTION)
backup_context = await get_backup_context(target.id, target.private)
if backup_context:
context.set_context(
backup_context, target.id, target.private
) # 加载历史记录
logger.info(f"已恢复会话 {target.id} 的上下文备份~")
handler = MarshoHandler(client, context)
try: try:
user_id = event.get_user_id() user_nickname = await get_nickname_by_user_id(event.get_user_id())
nicknames = await get_nicknames() if not user_nickname:
user_nickname = nicknames.get(user_id, "")
if user_nickname != "":
nickname_prompt = (
f"\n*此消息的说话者id为:{user_id},名字为:{user_nickname}*"
)
else:
nickname_prompt = ""
# 用户名无法获取,暂时注释 # 用户名无法获取,暂时注释
# user_nickname = event.sender.nickname # 未设置昵称时获取用户名 # user_nickname = event.sender.nickname # 未设置昵称时获取用户名
# nickname_prompt = f"\n*此消息的说话者:{user_nickname}" # nickname_prompt = f"\n*此消息的说话者:{user_nickname}"
@@ -253,189 +250,21 @@ async def marsho(
"※你未设置自己的昵称。推荐使用「nickname [昵称]」命令设置昵称来获得个性化(可能)回答。" "※你未设置自己的昵称。推荐使用「nickname [昵称]」命令设置昵称来获得个性化(可能)回答。"
).send() ).send()
is_support_image_model = ( usermsg = await handler.process_user_input(text, model_name)
model_name.lower()
in SUPPORT_IMAGE_MODELS + config.marshoai_additional_image_models
)
is_openai_new_model = model_name.lower() in OPENAI_NEW_MODELS
usermsg = [] if is_support_image_model else ""
for i in text: # type: ignore
if i.type == "text":
if is_support_image_model:
usermsg += [TextContentItem(text=i.data["text"] + nickname_prompt).as_dict()] # type: ignore
else:
usermsg += str(i.data["text"] + nickname_prompt) # type: ignore
elif i.type == "image":
if is_support_image_model:
usermsg.append( # type: ignore
ImageContentItem(
image_url=ImageUrl( # type: ignore
url=str(await get_image_b64(i.data["url"])) # type: ignore
) # type: ignore
).as_dict() # type: ignore
) # type: ignore
logger.info(f"输入图片 {i.data['url']}")
elif config.marshoai_enable_support_image_tip:
await UniMessage(
"*此模型不支持图片处理或管理员未启用此模型的图片支持。图片将被忽略。"
).send()
backup_context = await get_backup_context(target.id, target.private)
if backup_context:
context.set_context(
backup_context, target.id, target.private
) # 加载历史记录
logger.info(f"已恢复会话 {target.id} 的上下文备份~")
context_msg = get_prompt(model_name) + context.build(target.id, target.private)
tools_lists = tools.tools_list + list( tools_lists = tools.tools_list + list(
map(lambda v: v.data(), get_function_calls().values()) map(lambda v: v.data(), get_function_calls().values())
) )
logger.info(f"正在获取回答,模型:{model_name}") logger.info(f"正在获取回答,模型:{model_name}")
# logger.info(f"上下文:{context_msg}") # logger.info(f"上下文:{context_msg}")
response = await make_chat_openai( response = await handler.handle_common_chat(usermsg, model_name, tools_lists)
client=client,
model_name=model_name,
msg=context_msg + [UserMessage(content=usermsg).as_dict()], # type: ignore
tools=tools_lists if tools_lists else None, # TODO 临时追加函数,后期优化
)
# await UniMessage(str(response)).send() # await UniMessage(str(response)).send()
choice = response.choices[0] if response is not None:
# Sprint(choice) context_user, context_assistant = response
# 当tool_calls非空时将finish_reason设置为TOOL_CALLS context.append(context_user.as_dict(), target.id, target.private)
if choice.message.tool_calls != None and config.marshoai_fix_toolcalls: context.append(context_assistant.to_dict(), target.id, target.private)
choice.finish_reason = "tool_calls"
logger.info(f"完成原因:{choice.finish_reason}")
if choice.finish_reason == CompletionsFinishReason.STOPPED:
# 当对话成功时将dict的上下文添加到上下文类中
context.append(
UserMessage(content=usermsg).as_dict(), target.id, target.private # type: ignore
)
##### DeepSeek-R1 兼容部分 #####
choice_msg_content, choice_msg_thinking, choice_msg_after = (
extract_content_and_think(choice.message)
)
if choice_msg_thinking and config.marshoai_send_thinking:
await UniMessage("思维链:\n" + choice_msg_thinking).send()
##### 兼容部分结束 #####
context.append(choice_msg_after.to_dict(), target.id, target.private)
if [target.id, target.private] not in target_list:
target_list.append([target.id, target.private])
# 对话成功发送消息
if config.marshoai_enable_richtext_parse:
await (await parse_richtext(str(choice_msg_content))).send(
reply_to=True
)
else:
await UniMessage(str(choice_msg_content)).send(reply_to=True)
elif choice.finish_reason == CompletionsFinishReason.CONTENT_FILTERED:
# 对话失败,消息过滤
await UniMessage("*已被内容过滤器过滤。请调整聊天内容后重试。").send(
reply_to=True
)
return
elif choice.finish_reason == CompletionsFinishReason.TOOL_CALLS:
# function call
# 需要获取额外信息,调用函数工具
tool_msg = []
while choice.message.tool_calls != None:
# await UniMessage(str(response)).send()
tool_calls = choice.message.tool_calls
# try:
# if tool_calls[0]["function"]["name"].startswith("$"):
# choice.message.tool_calls[0][
# "type"
# ] = "builtin_function" # 兼容 moonshot AI 内置函数的临时方案
# except:
# pass
tool_msg.append(choice.message)
for tool_call in tool_calls:
try:
function_args = json.loads(tool_call.function.arguments)
except json.JSONDecodeError:
function_args = json.loads(
tool_call.function.arguments.replace("'", '"')
)
# 删除args的placeholder参数
if "placeholder" in function_args:
del function_args["placeholder"]
logger.info(
f"调用函数 {tool_call.function.name.replace('-', '.')}\n参数:"
+ "\n".join([f"{k}={v}" for k, v in function_args.items()])
)
await UniMessage(
f"调用函数 {tool_call.function.name.replace('-', '.')}\n参数:"
+ "\n".join([f"{k}={v}" for k, v in function_args.items()])
).send()
# TODO 临时追加插件函数,若工具中没有则调用插件函数
if tools.has_function(tool_call.function.name):
logger.debug(f"调用工具函数 {tool_call.function.name}")
func_return = await tools.call(
tool_call.function.name, function_args
) # 获取返回值
else:
if caller := get_function_calls().get(tool_call.function.name):
logger.debug(f"调用插件函数 {caller.full_name}")
# 权限检查,规则检查 TODO
# 实现依赖注入检查函数参数及参数注解类型对Event类型的参数进行注入
func_return = await caller.with_ctx(
SessionContext(
bot=bot,
event=event,
state=state,
matcher=matcher,
)
).call(**function_args)
else:
logger.error(
f"未找到函数 {tool_call.function.name.replace('-', '.')}"
)
func_return = f"未找到函数 {tool_call.function.name.replace('-', '.')}"
tool_msg.append(
ToolMessage(tool_call_id=tool_call.id, content=func_return).as_dict() # type: ignore
)
# tool_msg[0]["tool_calls"][0]["type"] = "builtin_function"
# await UniMessage(str(tool_msg)).send()
request_msg = context_msg + [UserMessage(content=usermsg).as_dict()] + tool_msg # type: ignore
response = await make_chat_openai(
client=client,
model_name=model_name,
msg=request_msg, # type: ignore
tools=(
tools_lists if tools_lists else None
), # TODO 临时追加函数,后期优化
)
choice = response.choices[0]
# 当tool_calls非空时将finish_reason设置为TOOL_CALLS
if choice.message.tool_calls != None:
choice.finish_reason = CompletionsFinishReason.TOOL_CALLS
if choice.finish_reason == CompletionsFinishReason.STOPPED:
# 对话成功 添加上下文
context.append(
UserMessage(content=usermsg).as_dict(), target.id, target.private # type: ignore
)
# context.append(tool_msg, target.id, target.private)
choice_msg_dict = choice.message.to_dict()
if "reasoning_content" in choice_msg_dict:
del choice_msg_dict["reasoning_content"]
context.append(choice_msg_dict, target.id, target.private)
# 发送消息
if config.marshoai_enable_richtext_parse:
await (await parse_richtext(str(choice.message.content))).send(
reply_to=True
)
else:
await UniMessage(str(choice.message.content)).send(reply_to=True)
else:
await marsho_cmd.finish(f"意外的完成原因:{choice.finish_reason}")
else: else:
await marsho_cmd.finish(f"意外的完成原因:{choice.finish_reason}") return
except Exception as e: except Exception as e:
await UniMessage(str(e) + suggest_solution(str(e))).send() await UniMessage(str(e) + suggest_solution(str(e))).send()
traceback.print_exc() traceback.print_exc()
@@ -450,12 +279,10 @@ with contextlib.suppress(ImportError): # 优化先不做()
@poke_notify.handle() @poke_notify.handle()
async def poke(event: Event): async def poke(event: Event):
user_id = event.get_user_id() user_nickname = await get_nickname_by_user_id(event.get_user_id())
nicknames = await get_nicknames()
user_nickname = nicknames.get(user_id, "")
try: try:
if config.marshoai_poke_suffix != "": if config.marshoai_poke_suffix != "":
logger.info(f"收到戳一戳,用户昵称:{user_nickname}用户ID{user_id}") logger.info(f"收到戳一戳,用户昵称:{user_nickname}")
response = await make_chat_openai( response = await make_chat_openai(
client=client, client=client,
model_name=model_name, model_name=model_name,

View File

@@ -9,7 +9,25 @@ import traceback
from nonebot import logger from nonebot import logger
from .config import config from .config import config
from .util import *
class Cache:
    """
    In-memory key/value cache backing the ``from_cache`` / ``update_to_cache``
    decorators. Entries live for the process lifetime (no eviction).
    """

    def __init__(self):
        # Plain dict storage; keys are the fixed cache-key strings used by
        # the decorators (e.g. "praises", "nickname").
        self.cache = {}

    def get(self, key):
        """Return the cached value for *key*, or None when absent.

        A miss does not insert a placeholder entry, so lookups never grow
        the underlying dict.
        """
        return self.cache.get(key)

    def set(self, key, value):
        """Store *value* under *key*, overwriting any existing entry."""
        self.cache[key] = value
class MarshoContext: class MarshoContext:

View File

@@ -70,8 +70,8 @@ class Caller:
): ):
return False, "告诉用户 Permission Denied 权限不足" return False, "告诉用户 Permission Denied 权限不足"
if self.ctx.state is None: # if self.ctx.state is None:
return False, "State is None" # return False, "State is None"
if self._rule and not await self._rule( if self._rule and not await self._rule(
self.ctx.bot, self.ctx.event, self.ctx.state self.ctx.bot, self.ctx.event, self.ctx.state
): ):
@@ -115,6 +115,10 @@ class Caller:
# 检查函数签名,确定依赖注入参数 # 检查函数签名,确定依赖注入参数
sig = inspect.signature(func) sig = inspect.signature(func)
for name, param in sig.parameters.items(): for name, param in sig.parameters.items():
# if param.annotation == T_State:
# self.di.state = name
# continue # 防止后续判断T_State子类时报错
if issubclass(param.annotation, Event) or isinstance( if issubclass(param.annotation, Event) or isinstance(
param.annotation, Event param.annotation, Event
): ):
@@ -133,9 +137,6 @@ class Caller:
): ):
self.di.matcher = name self.di.matcher = name
if param.annotation == T_State:
self.di.state = name
# 检查默认值情况 # 检查默认值情况
for name, param in sig.parameters.items(): for name, param in sig.parameters.items():
if param.default is not inspect.Parameter.empty: if param.default is not inspect.Parameter.empty:

View File

@@ -19,7 +19,7 @@ class SessionContext(BaseModel):
bot: Bot bot: Bot
event: Event event: Event
matcher: Matcher matcher: Matcher
state: T_State state: T_State | None
caller: Any = None caller: Any = None
class Config: class Config:
@@ -30,5 +30,5 @@ class SessionContextDepends(BaseModel):
bot: str | None = None bot: str | None = None
event: str | None = None event: str | None = None
matcher: str | None = None matcher: str | None = None
state: str | None = None # state: str | None = None
caller: str | None = None caller: str | None = None

View File

@@ -20,13 +20,14 @@ from openai.types.chat import ChatCompletion, ChatCompletionMessage
from zhDateTime import DateTime from zhDateTime import DateTime
from ._types import DeveloperMessage from ._types import DeveloperMessage
from .cache.decos import *
from .config import config from .config import config
from .constants import * from .constants import CODE_BLOCK_PATTERN, IMG_LATEX_PATTERN, OPENAI_NEW_MODELS
from .deal_latex import ConvertLatex from .deal_latex import ConvertLatex
nickname_json = None # 记录昵称 # nickname_json = None # 记录昵称
praises_json = None # 记录夸赞名单 # praises_json = None # 记录夸赞名单
loaded_target_list = [] # 记录已恢复备份的上下文的列表 loaded_target_list: List[str] = [] # 记录已恢复备份的上下文的列表
NOT_GIVEN = NotGiven() NOT_GIVEN = NotGiven()
@@ -155,30 +156,29 @@ async def make_chat_openai(
) )
@from_cache("praises")
def get_praises(): def get_praises():
global praises_json praises_file = store.get_plugin_data_file(
if praises_json is None: "praises.json"
praises_file = store.get_plugin_data_file( ) # 夸赞名单文件使用localstore存储
"praises.json" if not praises_file.exists():
) # 夸赞名单文件使用localstore存储 with open(praises_file, "w", encoding="utf-8") as f:
if not praises_file.exists(): json.dump(_praises_init_data, f, ensure_ascii=False, indent=4)
with open(praises_file, "w", encoding="utf-8") as f: with open(praises_file, "r", encoding="utf-8") as f:
json.dump(_praises_init_data, f, ensure_ascii=False, indent=4) data = json.load(f)
with open(praises_file, "r", encoding="utf-8") as f: praises_json = data
data = json.load(f)
praises_json = data
return praises_json return praises_json
@update_to_cache("praises")
async def refresh_praises_json(): async def refresh_praises_json():
global praises_json
praises_file = store.get_plugin_data_file("praises.json") praises_file = store.get_plugin_data_file("praises.json")
if not praises_file.exists(): if not praises_file.exists():
with open(praises_file, "w", encoding="utf-8") as f: with open(praises_file, "w", encoding="utf-8") as f:
json.dump(_praises_init_data, f, ensure_ascii=False, indent=4) # 异步? json.dump(_praises_init_data, f, ensure_ascii=False, indent=4) # 异步?
async with aiofiles.open(praises_file, "r", encoding="utf-8") as f: async with aiofiles.open(praises_file, "r", encoding="utf-8") as f:
data = json.loads(await f.read()) data = json.loads(await f.read())
praises_json = data return data
def build_praises() -> str: def build_praises() -> str:
@@ -210,22 +210,21 @@ async def load_context_from_json(name: str, path: str) -> list:
return [] return []
@from_cache("nickname")
async def get_nicknames(): async def get_nicknames():
"""获取nickname_json, 优先来源于全局变量""" """获取nickname_json, 优先来源于缓存"""
global nickname_json filename = store.get_plugin_data_file("nickname.json")
if nickname_json is None: # noinspection PyBroadException
filename = store.get_plugin_data_file("nickname.json") try:
# noinspection PyBroadException async with aiofiles.open(filename, "r", encoding="utf-8") as f:
try: nickname_json = json.loads(await f.read())
async with aiofiles.open(filename, "r", encoding="utf-8") as f: except (json.JSONDecodeError, FileNotFoundError):
nickname_json = json.loads(await f.read()) nickname_json = {}
except Exception:
nickname_json = {}
return nickname_json return nickname_json
@update_to_cache("nickname")
async def set_nickname(user_id: str, name: str): async def set_nickname(user_id: str, name: str):
global nickname_json
filename = store.get_plugin_data_file("nickname.json") filename = store.get_plugin_data_file("nickname.json")
if not filename.exists(): if not filename.exists():
data = {} data = {}
@@ -237,19 +236,25 @@ async def set_nickname(user_id: str, name: str):
del data[user_id] del data[user_id]
with open(filename, "w", encoding="utf-8") as f: with open(filename, "w", encoding="utf-8") as f:
json.dump(data, f, ensure_ascii=False, indent=4) json.dump(data, f, ensure_ascii=False, indent=4)
nickname_json = data return data
async def get_nickname_by_user_id(user_id: str):
nickname_json = await get_nicknames()
return nickname_json.get(user_id, "")
@update_to_cache("nickname")
async def refresh_nickname_json(): async def refresh_nickname_json():
"""强制刷新nickname_json, 刷新全局变量""" """强制刷新nickname_json"""
global nickname_json
# noinspection PyBroadException # noinspection PyBroadException
try: try:
async with aiofiles.open( async with aiofiles.open(
store.get_plugin_data_file("nickname.json"), "r", encoding="utf-8" store.get_plugin_data_file("nickname.json"), "r", encoding="utf-8"
) as f: ) as f:
nickname_json = json.loads(await f.read()) nickname_json = json.loads(await f.read())
except Exception: return nickname_json
except (json.JSONDecodeError, FileNotFoundError):
logger.error("刷新 nickname_json 表错误:无法载入 nickname.json 文件") logger.error("刷新 nickname_json 表错误:无法载入 nickname.json 文件")

View File

@@ -81,4 +81,4 @@ test = [
] ]
[tool.ruff.lint] [tool.ruff.lint]
ignore = ["E402"] ignore = ["E402", "F405"]