From 5768b95b096bed2e333a90ccf8628b9ace0d9dc0 Mon Sep 17 00:00:00 2001 From: Akarin~ <60691961+Asankilp@users.noreply.github.com> Date: Fri, 4 Apr 2025 23:01:01 +0800 Subject: [PATCH] =?UTF-8?q?[WIP]=20=E8=A1=A8=E6=83=85=E5=9B=9E=E5=BA=94?= =?UTF-8?q?=E6=94=AF=E6=8C=81=20(#26)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * 初步支持&utils重构 * 戳一戳支持流式请求 * 移除未使用import * 解决类型问题 --- nonebot_plugin_marshoai/handler.py | 74 ++++-------------------- nonebot_plugin_marshoai/marsho.py | 30 +++++++--- nonebot_plugin_marshoai/util.py | 2 +- nonebot_plugin_marshoai/utils/request.py | 71 +++++++++++++++++++++++ pdm.lock | 8 +-- pyproject.toml | 2 +- 6 files changed, 111 insertions(+), 76 deletions(-) create mode 100644 nonebot_plugin_marshoai/utils/request.py diff --git a/nonebot_plugin_marshoai/handler.py b/nonebot_plugin_marshoai/handler.py index 4949785..f04285e 100644 --- a/nonebot_plugin_marshoai/handler.py +++ b/nonebot_plugin_marshoai/handler.py @@ -17,10 +17,9 @@ from nonebot.matcher import ( current_event, current_matcher, ) -from nonebot_plugin_alconna.uniseg import UniMessage, UniMsg +from nonebot_plugin_alconna.uniseg import UniMessage, UniMsg, get_message_id, get_target from openai import AsyncOpenAI, AsyncStream from openai.types.chat import ChatCompletion, ChatCompletionChunk, ChatCompletionMessage -from openai.types.chat.chat_completion import Choice from .config import config from .constants import SUPPORT_IMAGE_MODELS @@ -36,6 +35,7 @@ from .util import ( make_chat_openai, parse_richtext, ) +from .utils.request import process_chat_stream class MarshoHandler: @@ -50,8 +50,8 @@ class MarshoHandler: self.event: Event = current_event.get() # self.state: T_State = current_handler.get().state self.matcher: Matcher = current_matcher.get() - self.message_id: str = UniMessage.get_message_id(self.event) - self.target = UniMessage.get_target(self.event) + self.message_id: str = get_message_id(self.event) + self.target = get_target(self.event) async def process_user_input( self, user_input: UniMsg, model_name: str @@ -117,10 +117,10 @@ class MarshoHandler: async def handle_function_call( self, - completion: Union[ChatCompletion, AsyncStream[ChatCompletionChunk]], + completion: Union[ChatCompletion], user_message: Union[str, list], model_name: str, - tools_list: list, + tools_list: list | None = None, ): # function call # 需要获取额外信息,调用函数工具 @@ -188,7 +188,7 @@ class MarshoHandler: self, user_message: Union[str, list], model_name: str, - tools_list: list, + tools_list: list | None = None, stream: bool = False, tool_message: Optional[list] = None, ) -> Optional[Tuple[UserMessage, ChatCompletionMessage]]: @@ -257,9 +257,9 @@ class MarshoHandler: self, user_message: Union[str, list], model_name: str, - tools_list: list, + tools_list: list | None = None, tools_message: Optional[list] = None, - ) -> Union[ChatCompletion, None]: + ) -> ChatCompletion: """ 处理流式请求 """ @@ -272,56 +272,6 @@ class MarshoHandler: ) if isinstance(response, AsyncStream): - reasoning_contents = "" - answer_contents = "" - last_chunk = None - is_first_token_appeared = False - is_answering = False - async for chunk in response: - last_chunk = chunk - # print(chunk) - if not is_first_token_appeared: - logger.debug(f"{chunk.id}: 第一个 token 已出现") - is_first_token_appeared = True - if not chunk.choices: - logger.info("Usage:", chunk.usage) - else: - delta = chunk.choices[0].delta - if ( - hasattr(delta, "reasoning_content") - and delta.reasoning_content is not None - ): - reasoning_contents += delta.reasoning_content - else: - if not is_answering: - logger.debug( - f"{chunk.id}: 思维链已输出完毕或无 reasoning_content 字段输出" - ) - is_answering = True - if delta.content is not None: - answer_contents += delta.content - # print(last_chunk) - # 创建新的 ChatCompletion 对象 - if last_chunk and last_chunk.choices: - message = ChatCompletionMessage( - content=answer_contents, - role="assistant", - tool_calls=last_chunk.choices[0].delta.tool_calls, # type: ignore - ) - if reasoning_contents != "": - setattr(message, "reasoning_content", reasoning_contents) - choice = Choice( - finish_reason=last_chunk.choices[0].finish_reason, # type: ignore - index=last_chunk.choices[0].index, - message=message, - ) - return ChatCompletion( - id=last_chunk.id, - choices=[choice], - created=last_chunk.created, - model=last_chunk.model, - system_fingerprint=last_chunk.system_fingerprint, - object="chat.completion", - usage=last_chunk.usage, - ) - return None + return await process_chat_stream(response) + else: + raise TypeError("Unexpected response type for stream request") diff --git a/nonebot_plugin_marshoai/marsho.py b/nonebot_plugin_marshoai/marsho.py index 1944b25..8ef3ffd 100644 --- a/nonebot_plugin_marshoai/marsho.py +++ b/nonebot_plugin_marshoai/marsho.py @@ -15,7 +15,14 @@ from nonebot.params import CommandArg from nonebot.permission import SUPERUSER from nonebot.rule import to_me from nonebot.typing import T_State -from nonebot_plugin_alconna import MsgTarget, UniMessage, UniMsg, on_alconna +from nonebot_plugin_alconna import ( + Emoji, + MsgTarget, + UniMessage, + UniMsg, + message_reaction, + on_alconna, +) from .config import config from .constants import INTRODUCTION, SUPPORT_IMAGE_MODELS @@ -25,6 +32,7 @@ from .instances import client, context, model_name, target_list, tools from .metadata import metadata from .plugin.func_call.caller import get_function_calls from .util import * +from .utils.request import process_chat_stream async def at_enable(): @@ -226,6 +234,7 @@ async def marsho( if not text: # 发送说明 # await UniMessage(metadata.usage + "\n当前使用的模型:" + model_name).send() + await message_reaction(Emoji("38")) await marsho_cmd.finish(INTRODUCTION) backup_context = await get_backup_context(target.id, target.private) if backup_context: @@ -256,6 +265,7 @@ async def marsho( map(lambda v: v.data(), get_function_calls().values()) ) logger.info(f"正在获取回答,模型:{model_name}") + await message_reaction(Emoji("66")) # logger.info(f"上下文:{context_msg}") response = await handler.handle_common_chat( usermsg, model_name, tools_lists, config.marshoai_stream @@ -282,19 +292,23 @@ with contextlib.suppress(ImportError): # 优化先不做() async def poke(event: Event): user_nickname = await get_nickname_by_user_id(event.get_user_id()) + usermsg = await get_prompt(model_name) + [ + UserMessage(content=f"*{user_nickname}{config.marshoai_poke_suffix}"), + ] try: if config.marshoai_poke_suffix != "": logger.info(f"收到戳一戳,用户昵称:{user_nickname}") - response = await make_chat_openai( + + pre_response = await make_chat_openai( client=client, model_name=model_name, - msg=await get_prompt(model_name) - + [ - UserMessage( - content=f"*{user_nickname}{config.marshoai_poke_suffix}" - ), - ], + msg=usermsg, + stream=config.marshoai_stream, ) + if isinstance(pre_response, AsyncStream): + response = await process_chat_stream(pre_response) + else: + response = pre_response choice = response.choices[0] # type: ignore if choice.finish_reason == CompletionsFinishReason.STOPPED: content = extract_content_and_think(choice.message)[0] diff --git a/nonebot_plugin_marshoai/util.py b/nonebot_plugin_marshoai/util.py index 9d69efa..1dd9436 100755 --- a/nonebot_plugin_marshoai/util.py +++ b/nonebot_plugin_marshoai/util.py @@ -18,7 +18,7 @@ from nonebot_plugin_alconna import Text as TextMsg from nonebot_plugin_alconna import UniMessage from openai import AsyncOpenAI, AsyncStream, NotGiven from openai.types.chat import ChatCompletion, ChatCompletionChunk, ChatCompletionMessage -from zhDateTime import DateTime +from zhDateTime import DateTime # type: ignore from ._types import DeveloperMessage from .cache.decos import * diff --git a/nonebot_plugin_marshoai/utils/request.py b/nonebot_plugin_marshoai/utils/request.py new file mode 100644 index 0000000..ae83d11 --- /dev/null +++ b/nonebot_plugin_marshoai/utils/request.py @@ -0,0 +1,71 @@ +from nonebot.log import logger +from openai import AsyncStream +from openai.types.chat import ChatCompletion, ChatCompletionChunk, ChatCompletionMessage +from openai.types.chat.chat_completion import Choice + + +async def process_chat_stream( + stream: AsyncStream[ChatCompletionChunk], +) -> ChatCompletion: + reasoning_contents = "" + answer_contents = "" + last_chunk = None + is_first_token_appeared = False + is_answering = False + async for chunk in stream: + last_chunk = chunk + # print(chunk) + if not is_first_token_appeared: + logger.debug(f"{chunk.id}: 第一个 token 已出现") + is_first_token_appeared = True + if not chunk.choices: + logger.info("Usage:", chunk.usage) + else: + delta = chunk.choices[0].delta + if ( + hasattr(delta, "reasoning_content") + and delta.reasoning_content is not None + ): + reasoning_contents += delta.reasoning_content + else: + if not is_answering: + logger.debug( + f"{chunk.id}: 思维链已输出完毕或无 reasoning_content 字段输出" + ) + is_answering = True + if delta.content is not None: + answer_contents += delta.content + # print(last_chunk) + # 创建新的 ChatCompletion 对象 + if last_chunk and last_chunk.choices: + message = ChatCompletionMessage( + content=answer_contents, + role="assistant", + tool_calls=last_chunk.choices[0].delta.tool_calls, # type: ignore + ) + if reasoning_contents != "": + setattr(message, "reasoning_content", reasoning_contents) + choice = Choice( + finish_reason=last_chunk.choices[0].finish_reason, # type: ignore + index=last_chunk.choices[0].index, + message=message, + ) + return ChatCompletion( + id=last_chunk.id, + choices=[choice], + created=last_chunk.created, + model=last_chunk.model, + system_fingerprint=last_chunk.system_fingerprint, + object="chat.completion", + usage=last_chunk.usage, + ) + else: + return ChatCompletion( + id="", + choices=[], + created=0, + model="", + system_fingerprint="", + object="chat.completion", + usage=None, + ) diff --git a/pdm.lock b/pdm.lock index 6e8a898..a4b3650 100644 --- a/pdm.lock +++ b/pdm.lock @@ -5,7 +5,7 @@ groups = ["default", "dev", "test"] strategy = ["inherit_metadata"] lock_version = "4.5.0" -content_hash = "sha256:d7ab3d9ca825de512d4f87ec846f7fddcf3d5796a7c9562e60c8c7d39c058817" +content_hash = "sha256:9dd3edfe69c332deac360af2685358e82c5dac0870900668534fc6f1d34040f8" [[metadata.targets]] requires_python = "~=3.10" @@ -1485,7 +1485,7 @@ files = [ [[package]] name = "nonebot-plugin-alconna" -version = "0.54.1" +version = "0.57.0" requires_python = ">=3.9" summary = "Alconna Adapter for Nonebot" groups = ["default"] @@ -1499,8 +1499,8 @@ dependencies = [ "tarina<0.7,>=0.6.8", ] files = [ - {file = "nonebot_plugin_alconna-0.54.1-py3-none-any.whl", hash = "sha256:4edb4b081cd64ce37717c7a92d31aadd2cf287a5a0adc2ac86ed82d9bcad5048"}, - {file = "nonebot_plugin_alconna-0.54.1.tar.gz", hash = "sha256:66fae03120b8eff25bb0027d65f149e399aa6f73c7585ebdd388d1904cecdeee"}, + {file = "nonebot_plugin_alconna-0.57.0-py3-none-any.whl", hash = "sha256:6c4bcce1a9aa176244b4c011b19b1cea00269c4c6794cd4e90d8dd7990ec3ec9"}, + {file = "nonebot_plugin_alconna-0.57.0.tar.gz", hash = "sha256:7a9a4bf373f3f6836611dbde1a0917b84441a534dd6f2b20dae3ba6fff142858"}, ] [[package]] diff --git a/pyproject.toml b/pyproject.toml index ec57a5a..57d9167 100755 --- a/pyproject.toml +++ b/pyproject.toml @@ -10,7 +10,7 @@ authors = [ ] dependencies = [ "nonebot2>=2.4.0", - "nonebot-plugin-alconna>=0.48.0", + "nonebot-plugin-alconna>=0.57.0", "nonebot-plugin-localstore>=0.7.1", "zhDatetime>=2.0.0", "aiohttp>=3.9",