mirror of
				https://github.com/nonebot/nonebot2.git
				synced 2025-10-26 04:26:39 +00:00 
			
		
		
		
	🐛 Fix: State ForwardRef 检测错误 (#2698)
This commit is contained in:
		| @@ -17,8 +17,14 @@ from pydantic.fields import FieldInfo as PydanticFieldInfo | ||||
|  | ||||
| from nonebot.dependencies import Param, Dependent | ||||
| from nonebot.dependencies.utils import check_field_type | ||||
| from nonebot.typing import T_State, T_Handler, T_DependencyCache | ||||
| from nonebot.compat import FieldInfo, ModelField, PydanticUndefined, extract_field_info | ||||
| from nonebot.typing import ( | ||||
|     _STATE_FLAG, | ||||
|     T_State, | ||||
|     T_Handler, | ||||
|     T_DependencyCache, | ||||
|     origin_is_annotated, | ||||
| ) | ||||
| from nonebot.utils import ( | ||||
|     get_name, | ||||
|     run_sync, | ||||
| @@ -349,7 +355,9 @@ class StateParam(Param): | ||||
|         cls, param: inspect.Parameter, allow_types: tuple[type[Param], ...] | ||||
|     ) -> Optional[Self]: | ||||
|         # param type is T_State | ||||
|         if param.annotation is T_State: | ||||
|         if origin_is_annotated( | ||||
|             get_origin(param.annotation) | ||||
|         ) and _STATE_FLAG in get_args(param.annotation): | ||||
|             return cls() | ||||
|         # legacy: param is named "state" and has no type annotation | ||||
|         elif param.annotation == param.empty and param.name == "state": | ||||
|   | ||||
| @@ -108,7 +108,15 @@ def evaluate_forwardref( | ||||
|  | ||||
|  | ||||
| # state | ||||
| T_State: TypeAlias = dict[t.Any, t.Any] | ||||
| # use annotated flag to avoid ForwardRef recreate generic type (py >= 3.11) | ||||
| class StateFlag: | ||||
|     def __repr__(self) -> str: | ||||
|         return "StateFlag()" | ||||
|  | ||||
|  | ||||
| _STATE_FLAG = StateFlag() | ||||
|  | ||||
| T_State: TypeAlias = t.Annotated[dict[t.Any, t.Any], _STATE_FLAG] | ||||
| """事件处理状态 State 类型""" | ||||
|  | ||||
| _DependentCallable: TypeAlias = t.Union[ | ||||
|   | ||||
| @@ -7,6 +7,10 @@ async def get_bot(b: Bot) -> Bot: | ||||
|     return b | ||||
|  | ||||
|  | ||||
| async def postpone_bot(b: "Bot") -> Bot: | ||||
|     return b | ||||
|  | ||||
|  | ||||
| async def legacy_bot(bot): | ||||
|     return bot | ||||
|  | ||||
|   | ||||
| @@ -8,6 +8,10 @@ async def event(e: Event) -> Event: | ||||
|     return e | ||||
|  | ||||
|  | ||||
| async def postpone_event(e: "Event") -> Event: | ||||
|     return e | ||||
|  | ||||
|  | ||||
| async def legacy_event(event): | ||||
|     return event | ||||
|  | ||||
|   | ||||
| @@ -9,6 +9,10 @@ async def matcher(m: Matcher) -> Matcher: | ||||
|     return m | ||||
|  | ||||
|  | ||||
| async def postpone_matcher(m: "Matcher") -> Matcher: | ||||
|     return m | ||||
|  | ||||
|  | ||||
| async def legacy_matcher(matcher): | ||||
|     return matcher | ||||
|  | ||||
| @@ -27,7 +31,7 @@ class BarMatcher(Matcher): ... | ||||
|  | ||||
|  | ||||
| async def union_matcher( | ||||
|     m: Union[FooMatcher, BarMatcher] | ||||
|     m: Union[FooMatcher, BarMatcher], | ||||
| ) -> Union[FooMatcher, BarMatcher]: | ||||
|     return m | ||||
|  | ||||
|   | ||||
| @@ -25,6 +25,10 @@ async def state(x: T_State) -> T_State: | ||||
|     return x | ||||
|  | ||||
|  | ||||
| async def postpone_state(x: "T_State") -> T_State: | ||||
|     return x | ||||
|  | ||||
|  | ||||
| async def legacy_state(state): | ||||
|     return state | ||||
|  | ||||
|   | ||||
| @@ -129,6 +129,7 @@ async def test_bot(app: App): | ||||
|         union_bot, | ||||
|         legacy_bot, | ||||
|         generic_bot, | ||||
|         postpone_bot, | ||||
|         not_legacy_bot, | ||||
|         generic_bot_none, | ||||
|     ) | ||||
| @@ -138,6 +139,11 @@ async def test_bot(app: App): | ||||
|         ctx.pass_params(bot=bot) | ||||
|         ctx.should_return(bot) | ||||
|  | ||||
|     async with app.test_dependent(postpone_bot, allow_types=[BotParam]) as ctx: | ||||
|         bot = ctx.create_bot() | ||||
|         ctx.pass_params(bot=bot) | ||||
|         ctx.should_return(bot) | ||||
|  | ||||
|     async with app.test_dependent(legacy_bot, allow_types=[BotParam]) as ctx: | ||||
|         bot = ctx.create_bot() | ||||
|         ctx.pass_params(bot=bot) | ||||
| @@ -188,6 +194,7 @@ async def test_event(app: App): | ||||
|         legacy_event, | ||||
|         event_message, | ||||
|         generic_event, | ||||
|         postpone_event, | ||||
|         event_plain_text, | ||||
|         not_legacy_event, | ||||
|         generic_event_none, | ||||
| @@ -201,6 +208,10 @@ async def test_event(app: App): | ||||
|         ctx.pass_params(event=fake_event) | ||||
|         ctx.should_return(fake_event) | ||||
|  | ||||
|     async with app.test_dependent(postpone_event, allow_types=[EventParam]) as ctx: | ||||
|         ctx.pass_params(event=fake_event) | ||||
|         ctx.should_return(fake_event) | ||||
|  | ||||
|     async with app.test_dependent(legacy_event, allow_types=[EventParam]) as ctx: | ||||
|         ctx.pass_params(event=fake_event) | ||||
|         ctx.should_return(fake_event) | ||||
| @@ -273,6 +284,7 @@ async def test_state(app: App): | ||||
|         legacy_state, | ||||
|         command_start, | ||||
|         regex_matched, | ||||
|         postpone_state, | ||||
|         not_legacy_state, | ||||
|         command_whitespace, | ||||
|         shell_command_args, | ||||
| @@ -302,6 +314,10 @@ async def test_state(app: App): | ||||
|         ctx.pass_params(state=fake_state) | ||||
|         ctx.should_return(fake_state) | ||||
|  | ||||
|     async with app.test_dependent(postpone_state, allow_types=[StateParam]) as ctx: | ||||
|         ctx.pass_params(state=fake_state) | ||||
|         ctx.should_return(fake_state) | ||||
|  | ||||
|     async with app.test_dependent(legacy_state, allow_types=[StateParam]) as ctx: | ||||
|         ctx.pass_params(state=fake_state) | ||||
|         ctx.should_return(fake_state) | ||||
| @@ -414,6 +430,7 @@ async def test_matcher(app: App): | ||||
|         union_matcher, | ||||
|         legacy_matcher, | ||||
|         generic_matcher, | ||||
|         postpone_matcher, | ||||
|         not_legacy_matcher, | ||||
|         generic_matcher_none, | ||||
|     ) | ||||
| @@ -425,6 +442,10 @@ async def test_matcher(app: App): | ||||
|         ctx.pass_params(matcher=fake_matcher) | ||||
|         ctx.should_return(fake_matcher) | ||||
|  | ||||
|     async with app.test_dependent(postpone_matcher, allow_types=[MatcherParam]) as ctx: | ||||
|         ctx.pass_params(matcher=fake_matcher) | ||||
|         ctx.should_return(fake_matcher) | ||||
|  | ||||
|     async with app.test_dependent(legacy_matcher, allow_types=[MatcherParam]) as ctx: | ||||
|         ctx.pass_params(matcher=fake_matcher) | ||||
|         ctx.should_return(fake_matcher) | ||||
|   | ||||
		Reference in New Issue
	
	Block a user