mirror of
				https://github.com/nonebot/nonebot2.git
				synced 2025-11-04 00:46:43 +00:00 
			
		
		
		
	add startup shutdown deco
This commit is contained in:
		@@ -95,11 +95,12 @@ class BaseMessageSegment(abc.ABC):
 | 
				
			|||||||
class BaseMessage(list, abc.ABC):
 | 
					class BaseMessage(list, abc.ABC):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def __init__(self,
 | 
					    def __init__(self,
 | 
				
			||||||
                 message: Union[str, BaseMessageSegment, "BaseMessage"] = None,
 | 
					                 message: Union[str, dict, list, BaseMessageSegment,
 | 
				
			||||||
 | 
					                                "BaseMessage"] = None,
 | 
				
			||||||
                 *args,
 | 
					                 *args,
 | 
				
			||||||
                 **kwargs):
 | 
					                 **kwargs):
 | 
				
			||||||
        super().__init__(*args, **kwargs)
 | 
					        super().__init__(*args, **kwargs)
 | 
				
			||||||
        if isinstance(message, str):
 | 
					        if isinstance(message, (str, dict, list)):
 | 
				
			||||||
            self.extend(self._construct(message))
 | 
					            self.extend(self._construct(message))
 | 
				
			||||||
        elif isinstance(message, BaseMessage):
 | 
					        elif isinstance(message, BaseMessage):
 | 
				
			||||||
            self.extend(message)
 | 
					            self.extend(message)
 | 
				
			||||||
@@ -111,7 +112,7 @@ class BaseMessage(list, abc.ABC):
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
    @staticmethod
 | 
					    @staticmethod
 | 
				
			||||||
    @abc.abstractmethod
 | 
					    @abc.abstractmethod
 | 
				
			||||||
    def _construct(msg: str) -> Iterable[BaseMessageSegment]:
 | 
					    def _construct(msg: Union[str, dict, list]) -> Iterable[BaseMessageSegment]:
 | 
				
			||||||
        raise NotImplementedError
 | 
					        raise NotImplementedError
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def __add__(
 | 
					    def __add__(
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -5,12 +5,11 @@ import re
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
import httpx
 | 
					import httpx
 | 
				
			||||||
 | 
					
 | 
				
			||||||
# from nonebot.event import Event
 | 
					 | 
				
			||||||
from nonebot.config import Config
 | 
					from nonebot.config import Config
 | 
				
			||||||
from nonebot.message import handle_event
 | 
					from nonebot.message import handle_event
 | 
				
			||||||
from nonebot.exception import ApiNotAvailable
 | 
					from nonebot.exception import ApiNotAvailable
 | 
				
			||||||
from nonebot.typing import Tuple, Iterable, Optional, overrides, WebSocket
 | 
					 | 
				
			||||||
from nonebot.adapters import BaseBot, BaseEvent, BaseMessage, BaseMessageSegment
 | 
					from nonebot.adapters import BaseBot, BaseEvent, BaseMessage, BaseMessageSegment
 | 
				
			||||||
 | 
					from nonebot.typing import Union, Tuple, Iterable, Optional, overrides, WebSocket
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def escape(s: str, *, escape_comma: bool = True) -> str:
 | 
					def escape(s: str, *, escape_comma: bool = True) -> str:
 | 
				
			||||||
@@ -98,6 +97,10 @@ class Bot(BaseBot):
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
class Event(BaseEvent):
 | 
					class Event(BaseEvent):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def __init__(self, raw_event: dict):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        super().__init__(raw_event)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @property
 | 
					    @property
 | 
				
			||||||
    @overrides(BaseEvent)
 | 
					    @overrides(BaseEvent)
 | 
				
			||||||
    def type(self):
 | 
					    def type(self):
 | 
				
			||||||
@@ -286,7 +289,14 @@ class Message(BaseMessage):
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
    @staticmethod
 | 
					    @staticmethod
 | 
				
			||||||
    @overrides(BaseMessage)
 | 
					    @overrides(BaseMessage)
 | 
				
			||||||
    def _construct(msg: str) -> Iterable[MessageSegment]:
 | 
					    def _construct(msg: Union[str, dict, list]) -> Iterable[MessageSegment]:
 | 
				
			||||||
 | 
					        if isinstance(msg, dict):
 | 
				
			||||||
 | 
					            yield MessageSegment(msg["type"], msg.get("data") or {})
 | 
				
			||||||
 | 
					            return
 | 
				
			||||||
 | 
					        elif isinstance(msg, list):
 | 
				
			||||||
 | 
					            for seg in msg:
 | 
				
			||||||
 | 
					                yield MessageSegment(seg["type"], seg.get("data") or {})
 | 
				
			||||||
 | 
					            return
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        def _iter_message() -> Iterable[Tuple[str, str]]:
 | 
					        def _iter_message() -> Iterable[Tuple[str, str]]:
 | 
				
			||||||
            text_begin = 0
 | 
					            text_begin = 0
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -5,7 +5,7 @@ import abc
 | 
				
			|||||||
from ipaddress import IPv4Address
 | 
					from ipaddress import IPv4Address
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from nonebot.config import Env, Config
 | 
					from nonebot.config import Env, Config
 | 
				
			||||||
from nonebot.typing import Bot, Dict, Optional
 | 
					from nonebot.typing import Bot, Dict, Optional, Callable
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class BaseDriver(abc.ABC):
 | 
					class BaseDriver(abc.ABC):
 | 
				
			||||||
@@ -35,6 +35,14 @@ class BaseDriver(abc.ABC):
 | 
				
			|||||||
    def bots(self) -> Dict[int, Bot]:
 | 
					    def bots(self) -> Dict[int, Bot]:
 | 
				
			||||||
        return self._clients
 | 
					        return self._clients
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @abc.abstractmethod
 | 
				
			||||||
 | 
					    def on_startup(self, func: Callable) -> Callable:
 | 
				
			||||||
 | 
					        raise NotImplementedError
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @abc.abstractmethod
 | 
				
			||||||
 | 
					    def on_shutdown(self, func: Callable) -> Callable:
 | 
				
			||||||
 | 
					        raise NotImplementedError
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @abc.abstractmethod
 | 
					    @abc.abstractmethod
 | 
				
			||||||
    def run(self,
 | 
					    def run(self,
 | 
				
			||||||
            host: Optional[IPv4Address] = None,
 | 
					            host: Optional[IPv4Address] = None,
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -14,9 +14,9 @@ from fastapi import Body, Header, Response, WebSocket as FastAPIWebSocket
 | 
				
			|||||||
from nonebot.log import logger
 | 
					from nonebot.log import logger
 | 
				
			||||||
from nonebot.config import Env, Config
 | 
					from nonebot.config import Env, Config
 | 
				
			||||||
from nonebot.utils import DataclassEncoder
 | 
					from nonebot.utils import DataclassEncoder
 | 
				
			||||||
from nonebot.typing import Optional, overrides
 | 
					 | 
				
			||||||
from nonebot.adapters.cqhttp import Bot as CQBot
 | 
					from nonebot.adapters.cqhttp import Bot as CQBot
 | 
				
			||||||
from nonebot.drivers import BaseDriver, BaseWebSocket
 | 
					from nonebot.drivers import BaseDriver, BaseWebSocket
 | 
				
			||||||
 | 
					from nonebot.typing import Optional, Callable, overrides
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class Driver(BaseDriver):
 | 
					class Driver(BaseDriver):
 | 
				
			||||||
@@ -38,7 +38,7 @@ class Driver(BaseDriver):
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
    @property
 | 
					    @property
 | 
				
			||||||
    @overrides(BaseDriver)
 | 
					    @overrides(BaseDriver)
 | 
				
			||||||
    def server_app(self):
 | 
					    def server_app(self) -> FastAPI:
 | 
				
			||||||
        return self._server_app
 | 
					        return self._server_app
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @property
 | 
					    @property
 | 
				
			||||||
@@ -51,6 +51,14 @@ class Driver(BaseDriver):
 | 
				
			|||||||
    def logger(self) -> logging.Logger:
 | 
					    def logger(self) -> logging.Logger:
 | 
				
			||||||
        return logging.getLogger("fastapi")
 | 
					        return logging.getLogger("fastapi")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @overrides(BaseDriver)
 | 
				
			||||||
 | 
					    def on_startup(self, func: Callable) -> Callable:
 | 
				
			||||||
 | 
					        return self.server_app.on_event("startup")(func)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @overrides(BaseDriver)
 | 
				
			||||||
 | 
					    def on_shutdown(self, func: Callable) -> Callable:
 | 
				
			||||||
 | 
					        return self.server_app.on_event("shutdown")(func)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @overrides(BaseDriver)
 | 
					    @overrides(BaseDriver)
 | 
				
			||||||
    def run(self,
 | 
					    def run(self,
 | 
				
			||||||
            host: Optional[IPv4Address] = None,
 | 
					            host: Optional[IPv4Address] = None,
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -18,7 +18,6 @@ def event_preprocessor(func: PreProcessor) -> PreProcessor:
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
async def handle_event(bot: Bot, event: Event):
 | 
					async def handle_event(bot: Bot, event: Event):
 | 
				
			||||||
    # TODO: PreProcess
 | 
					 | 
				
			||||||
    coros = []
 | 
					    coros = []
 | 
				
			||||||
    for preprocessor in _event_preprocessors:
 | 
					    for preprocessor in _event_preprocessors:
 | 
				
			||||||
        coros.append(preprocessor(bot, event))
 | 
					        coros.append(preprocessor(bot, event))
 | 
				
			||||||
 
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user