diff --git a/app/routers.py b/app/routers.py index e92ad30..6a67867 100644 --- a/app/routers.py +++ b/app/routers.py @@ -14,6 +14,7 @@ TriggerDescribeIn, SendMessageResponse, SendMessageIn, + ActionTriggerIn, ) router = APIRouter() @@ -64,6 +65,14 @@ async def describe(body: TriggerDescribeIn): return {"trigger_id": trigger_id, "trigger_type": trigger_type} +@router.post("/action", response_model=TriggerResponse) +async def action(body: ActionTriggerIn): + trigger_id = body.trigger_id + trigger_type = TriggerType.action.value + + taskqueue.put(trigger_id, discord.trigger_action, **body.dict()) + return {"trigger_id": trigger_id, "trigger_type": trigger_type} + @router.post("/upload", response_model=UploadResponse) async def upload_attachment(file: UploadFile): if not file.content_type.startswith("image/"): diff --git a/app/schema.py b/app/schema.py index 2d2b9f3..77311dc 100644 --- a/app/schema.py +++ b/app/schema.py @@ -37,6 +37,10 @@ class TriggerResponse(BaseModel): trigger_id: str trigger_type: str = "" +class ActionTriggerIn(BaseModel): + trigger_id: str = "" + custom_id: str = "" + msg_id: str = "" class UploadResponse(BaseModel): message: str = "success" diff --git a/lib/api/discord.py b/lib/api/discord.py index 7e8b37d..63da30c 100644 --- a/lib/api/discord.py +++ b/lib/api/discord.py @@ -23,6 +23,7 @@ class TriggerType(str, Enum): max_upscale = "max_upscale" reset = "reset" describe = "describe" + action = "action" async def trigger(payload: Dict[str, Any]): @@ -167,6 +168,17 @@ async def reset(msg_id: str, msg_hash: str, **kwargs): }, **kwargs) return await trigger(payload) +async def trigger_action(msg_id: str, custom_id: str, **kwargs): + kwargs = { + "message_flags": 0, + "message_id": msg_id, + } + payload = _trigger_payload(3, { + "component_type": 2, + "custom_id": custom_id + }, **kwargs) + return await trigger(payload) + async def describe(upload_filename: str, **kwargs): payload = _trigger_payload(2, { diff --git a/task/bot/_typing.py b/task/bot/_typing.py index 7ff8323..8128fb0 100644 --- a/task/bot/_typing.py +++ b/task/bot/_typing.py @@ -24,11 +24,16 @@ class Embed(TypedDict): image: EmbedsImage +class Action(TypedDict): + label: str + custom_id: str + + class CallbackData(TypedDict): type: str id: int content: str attachments: List[Attachment] embeds: List[Embed] - + actions: List[Action] trigger_id: str diff --git a/task/bot/handler.py b/task/bot/handler.py index 41c752c..737c31a 100644 --- a/task/bot/handler.py +++ b/task/bot/handler.py @@ -2,11 +2,11 @@ import re from typing import Dict, Union, Any -from discord import Message +from discord import Message, components, ActionRow from app.handler import PROMPT_PREFIX, PROMPT_SUFFIX from lib.api.callback import queue_release, callback -from task.bot._typing import CallbackData, Attachment, Embed +from task.bot._typing import CallbackData, Attachment, Embed, Action TRIGGER_ID_PATTERN = f"{PROMPT_PREFIX}(\w+?){PROMPT_SUFFIX}" # 消息 ID 正则 @@ -33,6 +33,16 @@ def match_trigger_id(content: str) -> Union[str, None]: match = re.findall(TRIGGER_ID_PATTERN, content) return match[0] if match else None +def get_action(message: Message) -> list[Action]: + res = [] + for component in message.components: + if isinstance(component, ActionRow): + for action in component.children: + if isinstance(action, components.Button) and action.label and action.custom_id: + res.append(Action(label = action.label, custom_id = action.custom_id)) + return res + + async def callback_trigger(trigger_id: str, trigger_status: str, message: Message): await callback(CallbackData( @@ -45,6 +55,7 @@ async def callback_trigger(trigger_id: str, trigger_status: str, message: Messag ], embeds=[], trigger_id=trigger_id, + actions=get_action(message) )) @@ -61,5 +72,6 @@ async def callback_describe(trigger_status: str, message: Message, embed: Dict[s Embed(**embed) ], trigger_id=trigger_id, + actions=get_action(message) )) return trigger_id diff --git a/util/_queue.py b/util/_queue.py index 1ed0639..9ef79c0 100644 --- a/util/_queue.py +++ b/util/_queue.py @@ -45,6 +45,7 @@ def put( self._wait_queue.append({ _trigger_id: Task(func, *args, **kwargs) }) + logger.debug(f"Task[{_trigger_id}] added to queue. Queue size: {len(self._wait_queue)}") while self._wait_queue and len(self._concur_queue) < self._concur_size: self._exec()