add startup shutdown deco

This commit is contained in:
yanyongyu 2020-08-11 10:44:05 +08:00
parent 2d90c35df6
commit b32d4a24d1
5 changed files with 36 additions and 10 deletions

View File

@ -95,11 +95,12 @@ class BaseMessageSegment(abc.ABC):
class BaseMessage(list, abc.ABC): class BaseMessage(list, abc.ABC):
def __init__(self, def __init__(self,
message: Union[str, BaseMessageSegment, "BaseMessage"] = None, message: Union[str, dict, list, BaseMessageSegment,
"BaseMessage"] = None,
*args, *args,
**kwargs): **kwargs):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
if isinstance(message, str): if isinstance(message, (str, dict, list)):
self.extend(self._construct(message)) self.extend(self._construct(message))
elif isinstance(message, BaseMessage): elif isinstance(message, BaseMessage):
self.extend(message) self.extend(message)
@ -111,7 +112,7 @@ class BaseMessage(list, abc.ABC):
@staticmethod @staticmethod
@abc.abstractmethod @abc.abstractmethod
def _construct(msg: str) -> Iterable[BaseMessageSegment]: def _construct(msg: Union[str, dict, list]) -> Iterable[BaseMessageSegment]:
raise NotImplementedError raise NotImplementedError
def __add__( def __add__(

View File

@ -5,12 +5,11 @@ import re
import httpx import httpx
# from nonebot.event import Event
from nonebot.config import Config from nonebot.config import Config
from nonebot.message import handle_event from nonebot.message import handle_event
from nonebot.exception import ApiNotAvailable from nonebot.exception import ApiNotAvailable
from nonebot.typing import Tuple, Iterable, Optional, overrides, WebSocket
from nonebot.adapters import BaseBot, BaseEvent, BaseMessage, BaseMessageSegment from nonebot.adapters import BaseBot, BaseEvent, BaseMessage, BaseMessageSegment
from nonebot.typing import Union, Tuple, Iterable, Optional, overrides, WebSocket
def escape(s: str, *, escape_comma: bool = True) -> str: def escape(s: str, *, escape_comma: bool = True) -> str:
@ -98,6 +97,10 @@ class Bot(BaseBot):
class Event(BaseEvent): class Event(BaseEvent):
def __init__(self, raw_event: dict):
super().__init__(raw_event)
@property @property
@overrides(BaseEvent) @overrides(BaseEvent)
def type(self): def type(self):
@ -286,7 +289,14 @@ class Message(BaseMessage):
@staticmethod @staticmethod
@overrides(BaseMessage) @overrides(BaseMessage)
def _construct(msg: str) -> Iterable[MessageSegment]: def _construct(msg: Union[str, dict, list]) -> Iterable[MessageSegment]:
if isinstance(msg, dict):
yield MessageSegment(msg["type"], msg.get("data") or {})
return
elif isinstance(msg, list):
for seg in msg:
yield MessageSegment(seg["type"], seg.get("data") or {})
return
def _iter_message() -> Iterable[Tuple[str, str]]: def _iter_message() -> Iterable[Tuple[str, str]]:
text_begin = 0 text_begin = 0

View File

@ -5,7 +5,7 @@ import abc
from ipaddress import IPv4Address from ipaddress import IPv4Address
from nonebot.config import Env, Config from nonebot.config import Env, Config
from nonebot.typing import Bot, Dict, Optional from nonebot.typing import Bot, Dict, Optional, Callable
class BaseDriver(abc.ABC): class BaseDriver(abc.ABC):
@ -35,6 +35,14 @@ class BaseDriver(abc.ABC):
def bots(self) -> Dict[int, Bot]: def bots(self) -> Dict[int, Bot]:
return self._clients return self._clients
@abc.abstractmethod
def on_startup(self, func: Callable) -> Callable:
raise NotImplementedError
@abc.abstractmethod
def on_shutdown(self, func: Callable) -> Callable:
raise NotImplementedError
@abc.abstractmethod @abc.abstractmethod
def run(self, def run(self,
host: Optional[IPv4Address] = None, host: Optional[IPv4Address] = None,

View File

@ -14,9 +14,9 @@ from fastapi import Body, Header, Response, WebSocket as FastAPIWebSocket
from nonebot.log import logger from nonebot.log import logger
from nonebot.config import Env, Config from nonebot.config import Env, Config
from nonebot.utils import DataclassEncoder from nonebot.utils import DataclassEncoder
from nonebot.typing import Optional, overrides
from nonebot.adapters.cqhttp import Bot as CQBot from nonebot.adapters.cqhttp import Bot as CQBot
from nonebot.drivers import BaseDriver, BaseWebSocket from nonebot.drivers import BaseDriver, BaseWebSocket
from nonebot.typing import Optional, Callable, overrides
class Driver(BaseDriver): class Driver(BaseDriver):
@ -38,7 +38,7 @@ class Driver(BaseDriver):
@property @property
@overrides(BaseDriver) @overrides(BaseDriver)
def server_app(self): def server_app(self) -> FastAPI:
return self._server_app return self._server_app
@property @property
@ -51,6 +51,14 @@ class Driver(BaseDriver):
def logger(self) -> logging.Logger: def logger(self) -> logging.Logger:
return logging.getLogger("fastapi") return logging.getLogger("fastapi")
@overrides(BaseDriver)
def on_startup(self, func: Callable) -> Callable:
return self.server_app.on_event("startup")(func)
@overrides(BaseDriver)
def on_shutdown(self, func: Callable) -> Callable:
return self.server_app.on_event("shutdown")(func)
@overrides(BaseDriver) @overrides(BaseDriver)
def run(self, def run(self,
host: Optional[IPv4Address] = None, host: Optional[IPv4Address] = None,

View File

@ -18,7 +18,6 @@ def event_preprocessor(func: PreProcessor) -> PreProcessor:
async def handle_event(bot: Bot, event: Event): async def handle_event(bot: Bot, event: Event):
# TODO: PreProcess
coros = [] coros = []
for preprocessor in _event_preprocessors: for preprocessor in _event_preprocessors:
coros.append(preprocessor(bot, event)) coros.append(preprocessor(bot, event))