mirror of
				https://github.com/LiteyukiStudio/LiteyukiBot.git
				synced 2025-10-30 23:46:30 +00:00 
			
		
		
		
	🐛 fix 通道类回调函数在进程间传递时无法序列号的问题
This commit is contained in:
		| @@ -9,11 +9,22 @@ Copyright (C) 2020-2024 LiteyukiStudio. All Rights Reserved | ||||
| @Software: PyCharm | ||||
| 该模块用于轻雪主进程和Nonebot子进程之间的通信 | ||||
| """ | ||||
| from liteyuki.comm.channel import Channel, chan | ||||
| from liteyuki.comm.channel import ( | ||||
|     Channel, | ||||
|     chan, | ||||
|     get_channel, | ||||
|     set_channel, | ||||
|     set_channels, | ||||
|     get_channels | ||||
| ) | ||||
| from liteyuki.comm.event import Event | ||||
|  | ||||
| __all__ = [ | ||||
|         "Channel", | ||||
|         "chan", | ||||
|         "Event", | ||||
|         "get_channel", | ||||
|         "set_channel", | ||||
|         "set_channels", | ||||
|         "get_channels" | ||||
| ] | ||||
|   | ||||
| @@ -10,8 +10,12 @@ Copyright (C) 2020-2024 LiteyukiStudio. All Rights Reserved | ||||
|  | ||||
| 本模块定义了一个通用的通道类,用于进程间通信 | ||||
| """ | ||||
| import functools | ||||
| import multiprocessing | ||||
| import threading | ||||
| from multiprocessing import Pipe | ||||
| from typing import Any, Optional, Callable, Awaitable, List, TypeAlias | ||||
| from uuid import uuid4 | ||||
|  | ||||
| from liteyuki.utils import is_coroutine_callable, run_coroutine | ||||
|  | ||||
| @@ -23,76 +27,89 @@ SYNC_FILTER_FUNC: TypeAlias = Callable[[Any], bool] | ||||
| ASYNC_FILTER_FUNC: TypeAlias = Callable[[Any], Awaitable[bool]] | ||||
| FILTER_FUNC: TypeAlias = SYNC_FILTER_FUNC | ASYNC_FILTER_FUNC | ||||
|  | ||||
| IS_MAIN_PROCESS = multiprocessing.current_process().name == "MainProcess" | ||||
|  | ||||
| _channel: dict[str, "Channel"] = {} | ||||
| _callback_funcs: dict[str, ON_RECEIVE_FUNC] = {} | ||||
|  | ||||
|  | ||||
| class Channel: | ||||
|     """ | ||||
|     通道类,用于进程间通信 | ||||
|     通道类,用于进程间通信,进程内不可用,仅限主进程和子进程之间通信 | ||||
|     有两种接收工作方式,但是只能选择一种,主动接收和被动接收,主动接收使用 `receive` 方法,被动接收使用 `on_receive` 装饰器 | ||||
|     """ | ||||
|  | ||||
|     def __init__(self): | ||||
|         self.receive_conn, self.send_conn = Pipe() | ||||
|     def __init__(self, _id: str): | ||||
|         self.main_send_conn, self.sub_receive_conn = Pipe() | ||||
|         self.sub_send_conn, self.main_receive_conn = Pipe() | ||||
|         self._closed = False | ||||
|         self._on_receive_funcs: List[ON_RECEIVE_FUNC] = [] | ||||
|         self._on_receive_funcs_with_receiver: dict[str, List[ON_RECEIVE_FUNC]] = {} | ||||
|         self._on_main_receive_funcs: list[str] = [] | ||||
|         self._on_sub_receive_funcs: list[str] = [] | ||||
|         self.name: str = _id | ||||
|  | ||||
|     def send(self, data: Any, receiver: Optional[str] = None): | ||||
|         self.is_main_receive_loop_running = False | ||||
|         self.is_sub_receive_loop_running = False | ||||
|  | ||||
|     def __str__(self): | ||||
|         return f"Channel({self.name})" | ||||
|  | ||||
|     def send(self, data: Any): | ||||
|         """ | ||||
|         发送数据 | ||||
|         Args: | ||||
|             data: 数据 | ||||
|             receiver: 接收者,如果为None则广播 | ||||
|         """ | ||||
|         if self._closed: | ||||
|             raise RuntimeError("Cannot send to a closed channel") | ||||
|         self.send_conn.send((data, receiver)) | ||||
|         if IS_MAIN_PROCESS: | ||||
|             print("主进程发送数据:", data) | ||||
|             self.main_send_conn.send(data) | ||||
|         else: | ||||
|             print("子进程发送数据:", data) | ||||
|             self.sub_send_conn.send(data) | ||||
|  | ||||
|     def receive(self, receiver: str = None) -> Any: | ||||
|     def receive(self) -> Any: | ||||
|         """ | ||||
|         接收数据 | ||||
|         Args: | ||||
|             receiver: 接收者,如果为None则接收任意数据 | ||||
|         """ | ||||
|         if self._closed: | ||||
|             raise RuntimeError("Cannot receive from a closed channel") | ||||
|  | ||||
|         while True: | ||||
|             # 判断receiver是否为None或者receiver是否等于接收者,是则接收数据,否则不动数据 | ||||
|             data, receiver_ = self.receive_conn.recv() | ||||
|             if receiver is None or receiver == receiver_: | ||||
|                 self._run_on_receive_funcs(data, receiver_) | ||||
|                 return data | ||||
|             self.send_conn.send((data, receiver_)) | ||||
|             if IS_MAIN_PROCESS: | ||||
|                 data = self.main_receive_conn.recv() | ||||
|                 print("主进程接收数据:", data) | ||||
|             else: | ||||
|                 data = self.sub_receive_conn.recv() | ||||
|                 print("子进程接收数据:", data) | ||||
|  | ||||
|     def peek(self) -> Optional[Any]: | ||||
|         """ | ||||
|         查看管道中的数据,不移除 | ||||
|         Returns: | ||||
|         """ | ||||
|         if self._closed: | ||||
|             raise RuntimeError("Cannot peek from a closed channel") | ||||
|         if self.receive_conn.poll(): | ||||
|             data, receiver = self.receive_conn.recv() | ||||
|             self.receive_conn.send((data, receiver)) | ||||
|             return data | ||||
|         return None | ||||
|  | ||||
|     def close(self): | ||||
|         """ | ||||
|         关闭通道 | ||||
|         """ | ||||
|         self._closed = True | ||||
|         self.receive_conn.close() | ||||
|         self.send_conn.close() | ||||
|         self.sub_receive_conn.close() | ||||
|         self.main_send_conn.close() | ||||
|         self.sub_send_conn.close() | ||||
|         self.main_receive_conn.close() | ||||
|  | ||||
|     def on_receive(self, filter_func: Optional[FILTER_FUNC] = None, receiver: Optional[str] = None) -> Callable[[ON_RECEIVE_FUNC], ON_RECEIVE_FUNC]: | ||||
|     def on_receive(self, filter_func: Optional[FILTER_FUNC] = None) -> Callable[[ON_RECEIVE_FUNC], ON_RECEIVE_FUNC]: | ||||
|         """ | ||||
|         接收数据并执行函数 | ||||
|         Args: | ||||
|             filter_func: 过滤函数,为None则不过滤 | ||||
|             receiver: 接收者, 为None则接收任意数据 | ||||
|         Returns: | ||||
|             装饰器,装饰一个函数在接收到数据后执行 | ||||
|         """ | ||||
|         if (not self.is_sub_receive_loop_running) and not IS_MAIN_PROCESS: | ||||
|             threading.Thread(target=self._start_sub_receive_loop).start() | ||||
|  | ||||
|         if (not self.is_main_receive_loop_running) and IS_MAIN_PROCESS: | ||||
|             threading.Thread(target=self._start_main_receive_loop).start() | ||||
|  | ||||
|         def decorator(func: ON_RECEIVE_FUNC) -> ON_RECEIVE_FUNC: | ||||
|             async def wrapper(data: Any) -> Any: | ||||
| @@ -105,28 +122,53 @@ class Channel: | ||||
|                             return | ||||
|                 return await func(data) | ||||
|  | ||||
|             if receiver is None: | ||||
|                 self._on_receive_funcs.append(wrapper) | ||||
|             function_id = str(uuid4()) | ||||
|             _callback_funcs[function_id] = wrapper | ||||
|             if IS_MAIN_PROCESS: | ||||
|                 self._on_main_receive_funcs.append(function_id) | ||||
|             else: | ||||
|                 if receiver not in self._on_receive_funcs_with_receiver: | ||||
|                     self._on_receive_funcs_with_receiver[receiver] = [] | ||||
|                 self._on_receive_funcs_with_receiver[receiver].append(wrapper) | ||||
|                 self._on_sub_receive_funcs.append(function_id) | ||||
|             return func | ||||
|  | ||||
|         return decorator | ||||
|  | ||||
|     def _run_on_receive_funcs(self, data: Any, receiver: Optional[str] = None): | ||||
|     def _run_on_main_receive_funcs(self, data: Any): | ||||
|         """ | ||||
|         运行接收函数 | ||||
|         Args: | ||||
|             data: 数据 | ||||
|         """ | ||||
|         if receiver is None: | ||||
|             for func in self._on_receive_funcs: | ||||
|                 run_coroutine(func(data)) | ||||
|         else: | ||||
|             for func in self._on_receive_funcs_with_receiver.get(receiver, []): | ||||
|                 run_coroutine(func(data)) | ||||
|         for func_id in self._on_main_receive_funcs: | ||||
|             func = _callback_funcs[func_id] | ||||
|             run_coroutine(func(data)) | ||||
|  | ||||
|     def _run_on_sub_receive_funcs(self, data: Any): | ||||
|         """ | ||||
|         运行接收函数 | ||||
|         Args: | ||||
|             data: 数据 | ||||
|         """ | ||||
|         for func_id in self._on_sub_receive_funcs: | ||||
|             func = _callback_funcs[func_id] | ||||
|             run_coroutine(func(data)) | ||||
|  | ||||
|     def _start_main_receive_loop(self): | ||||
|         """ | ||||
|         开始接收数据 | ||||
|         """ | ||||
|         self.is_main_receive_loop_running = True | ||||
|         while not self._closed: | ||||
|             data = self.main_receive_conn.recv() | ||||
|             self._run_on_main_receive_funcs(data) | ||||
|  | ||||
|     def _start_sub_receive_loop(self): | ||||
|         """ | ||||
|         开始接收数据 | ||||
|         """ | ||||
|         self.is_sub_receive_loop_running = True | ||||
|         while not self._closed: | ||||
|             data = self.sub_receive_conn.recv() | ||||
|             self._run_on_sub_receive_funcs(data) | ||||
|  | ||||
|     def __iter__(self): | ||||
|         return self | ||||
| @@ -136,4 +178,42 @@ class Channel: | ||||
|  | ||||
|  | ||||
| """默认通道实例,可直接从模块导入使用""" | ||||
| chan = Channel() | ||||
| chan = Channel("default") | ||||
|  | ||||
|  | ||||
| def set_channel(name: str, channel: Channel): | ||||
|     """ | ||||
|     设置通道实例 | ||||
|     Args: | ||||
|         name: 通道名称 | ||||
|         channel: 通道实例 | ||||
|     """ | ||||
|     _channel[name] = channel | ||||
|  | ||||
|  | ||||
| def set_channels(channels: dict[str, Channel]): | ||||
|     """ | ||||
|     设置通道实例 | ||||
|     Args: | ||||
|         channels: 通道名称 | ||||
|     """ | ||||
|     for name, channel in channels.items(): | ||||
|         _channel[name] = channel | ||||
|  | ||||
|  | ||||
| def get_channel(name: str) -> Optional[Channel]: | ||||
|     """ | ||||
|     获取通道实例 | ||||
|     Args: | ||||
|         name: 通道名称 | ||||
|     Returns: | ||||
|     """ | ||||
|     return _channel.get(name, None) | ||||
|  | ||||
|  | ||||
| def get_channels() -> dict[str, Channel]: | ||||
|     """ | ||||
|     获取通道实例 | ||||
|     Returns: | ||||
|     """ | ||||
|     return _channel | ||||
|   | ||||
		Reference in New Issue
	
	Block a user