🐛 fix 通道类回调函数在进程间传递时无法序列号的问题

This commit is contained in:
2024-08-10 22:25:41 +08:00
parent 3bd40e7271
commit 7107d03b72
66 changed files with 5112 additions and 4916 deletions

View File

@ -9,11 +9,22 @@ Copyright (C) 2020-2024 LiteyukiStudio. All Rights Reserved
@Software: PyCharm
该模块用于轻雪主进程和Nonebot子进程之间的通信
"""
from liteyuki.comm.channel import Channel, chan
from liteyuki.comm.channel import (
Channel,
chan,
get_channel,
set_channel,
set_channels,
get_channels
)
from liteyuki.comm.event import Event
__all__ = [
"Channel",
"chan",
"Event",
"get_channel",
"set_channel",
"set_channels",
"get_channels"
]

View File

@ -10,8 +10,12 @@ Copyright (C) 2020-2024 LiteyukiStudio. All Rights Reserved
本模块定义了一个通用的通道类,用于进程间通信
"""
import functools
import multiprocessing
import threading
from multiprocessing import Pipe
from typing import Any, Optional, Callable, Awaitable, List, TypeAlias
from uuid import uuid4
from liteyuki.utils import is_coroutine_callable, run_coroutine
@ -23,76 +27,89 @@ SYNC_FILTER_FUNC: TypeAlias = Callable[[Any], bool]
ASYNC_FILTER_FUNC: TypeAlias = Callable[[Any], Awaitable[bool]]
FILTER_FUNC: TypeAlias = SYNC_FILTER_FUNC | ASYNC_FILTER_FUNC
IS_MAIN_PROCESS = multiprocessing.current_process().name == "MainProcess"
_channel: dict[str, "Channel"] = {}
_callback_funcs: dict[str, ON_RECEIVE_FUNC] = {}
class Channel:
"""
通道类,用于进程间通信
通道类,用于进程间通信,进程内不可用,仅限主进程和子进程之间通信
有两种接收工作方式,但是只能选择一种,主动接收和被动接收,主动接收使用 `receive` 方法,被动接收使用 `on_receive` 装饰器
"""
def __init__(self):
self.receive_conn, self.send_conn = Pipe()
def __init__(self, _id: str):
self.main_send_conn, self.sub_receive_conn = Pipe()
self.sub_send_conn, self.main_receive_conn = Pipe()
self._closed = False
self._on_receive_funcs: List[ON_RECEIVE_FUNC] = []
self._on_receive_funcs_with_receiver: dict[str, List[ON_RECEIVE_FUNC]] = {}
self._on_main_receive_funcs: list[str] = []
self._on_sub_receive_funcs: list[str] = []
self.name: str = _id
def send(self, data: Any, receiver: Optional[str] = None):
self.is_main_receive_loop_running = False
self.is_sub_receive_loop_running = False
def __str__(self):
return f"Channel({self.name})"
def send(self, data: Any):
"""
发送数据
Args:
data: 数据
receiver: 接收者如果为None则广播
"""
if self._closed:
raise RuntimeError("Cannot send to a closed channel")
self.send_conn.send((data, receiver))
if IS_MAIN_PROCESS:
print("主进程发送数据:", data)
self.main_send_conn.send(data)
else:
print("子进程发送数据:", data)
self.sub_send_conn.send(data)
def receive(self, receiver: str = None) -> Any:
def receive(self) -> Any:
"""
接收数据
Args:
receiver: 接收者如果为None则接收任意数据
"""
if self._closed:
raise RuntimeError("Cannot receive from a closed channel")
while True:
# 判断receiver是否为None或者receiver是否等于接收者是则接收数据否则不动数据
data, receiver_ = self.receive_conn.recv()
if receiver is None or receiver == receiver_:
self._run_on_receive_funcs(data, receiver_)
return data
self.send_conn.send((data, receiver_))
if IS_MAIN_PROCESS:
data = self.main_receive_conn.recv()
print("主进程接收数据:", data)
else:
data = self.sub_receive_conn.recv()
print("子进程接收数据:", data)
def peek(self) -> Optional[Any]:
"""
查看管道中的数据,不移除
Returns:
"""
if self._closed:
raise RuntimeError("Cannot peek from a closed channel")
if self.receive_conn.poll():
data, receiver = self.receive_conn.recv()
self.receive_conn.send((data, receiver))
return data
return None
def close(self):
"""
关闭通道
"""
self._closed = True
self.receive_conn.close()
self.send_conn.close()
self.sub_receive_conn.close()
self.main_send_conn.close()
self.sub_send_conn.close()
self.main_receive_conn.close()
def on_receive(self, filter_func: Optional[FILTER_FUNC] = None, receiver: Optional[str] = None) -> Callable[[ON_RECEIVE_FUNC], ON_RECEIVE_FUNC]:
def on_receive(self, filter_func: Optional[FILTER_FUNC] = None) -> Callable[[ON_RECEIVE_FUNC], ON_RECEIVE_FUNC]:
"""
接收数据并执行函数
Args:
filter_func: 过滤函数为None则不过滤
receiver: 接收者, 为None则接收任意数据
Returns:
装饰器,装饰一个函数在接收到数据后执行
"""
if (not self.is_sub_receive_loop_running) and not IS_MAIN_PROCESS:
threading.Thread(target=self._start_sub_receive_loop).start()
if (not self.is_main_receive_loop_running) and IS_MAIN_PROCESS:
threading.Thread(target=self._start_main_receive_loop).start()
def decorator(func: ON_RECEIVE_FUNC) -> ON_RECEIVE_FUNC:
async def wrapper(data: Any) -> Any:
@ -105,28 +122,53 @@ class Channel:
return
return await func(data)
if receiver is None:
self._on_receive_funcs.append(wrapper)
function_id = str(uuid4())
_callback_funcs[function_id] = wrapper
if IS_MAIN_PROCESS:
self._on_main_receive_funcs.append(function_id)
else:
if receiver not in self._on_receive_funcs_with_receiver:
self._on_receive_funcs_with_receiver[receiver] = []
self._on_receive_funcs_with_receiver[receiver].append(wrapper)
self._on_sub_receive_funcs.append(function_id)
return func
return decorator
def _run_on_receive_funcs(self, data: Any, receiver: Optional[str] = None):
def _run_on_main_receive_funcs(self, data: Any):
"""
运行接收函数
Args:
data: 数据
"""
if receiver is None:
for func in self._on_receive_funcs:
run_coroutine(func(data))
else:
for func in self._on_receive_funcs_with_receiver.get(receiver, []):
run_coroutine(func(data))
for func_id in self._on_main_receive_funcs:
func = _callback_funcs[func_id]
run_coroutine(func(data))
def _run_on_sub_receive_funcs(self, data: Any):
"""
运行接收函数
Args:
data: 数据
"""
for func_id in self._on_sub_receive_funcs:
func = _callback_funcs[func_id]
run_coroutine(func(data))
def _start_main_receive_loop(self):
"""
开始接收数据
"""
self.is_main_receive_loop_running = True
while not self._closed:
data = self.main_receive_conn.recv()
self._run_on_main_receive_funcs(data)
def _start_sub_receive_loop(self):
"""
开始接收数据
"""
self.is_sub_receive_loop_running = True
while not self._closed:
data = self.sub_receive_conn.recv()
self._run_on_sub_receive_funcs(data)
def __iter__(self):
return self
@ -136,4 +178,42 @@ class Channel:
"""默认通道实例,可直接从模块导入使用"""
chan = Channel()
chan = Channel("default")
def set_channel(name: str, channel: Channel):
"""
设置通道实例
Args:
name: 通道名称
channel: 通道实例
"""
_channel[name] = channel
def set_channels(channels: dict[str, Channel]):
"""
设置通道实例
Args:
channels: 通道名称
"""
for name, channel in channels.items():
_channel[name] = channel
def get_channel(name: str) -> Optional[Channel]:
"""
获取通道实例
Args:
name: 通道名称
Returns:
"""
return _channel.get(name, None)
def get_channels() -> dict[str, Channel]:
"""
获取通道实例
Returns:
"""
return _channel