diff --git a/astrbot/core/config/default.py b/astrbot/core/config/default.py index 8d3b40593..b160b3008 100644 --- a/astrbot/core/config/default.py +++ b/astrbot/core/config/default.py @@ -268,6 +268,14 @@ "misskey_default_visibility": "public", "misskey_local_only": False, "misskey_enable_chat": True, + # download / security options + "misskey_allow_insecure_downloads": False, + "misskey_download_timeout": 15, + "misskey_download_chunk_size": 65536, + "misskey_max_download_bytes": None, + "misskey_enable_file_upload": True, + "misskey_upload_concurrency": 3, + "misskey_upload_folder": "", }, "Slack": { "id": "slack", @@ -396,6 +404,41 @@ "type": "bool", "hint": "启用后,机器人将会监听和响应私信聊天消息", }, + "misskey_enable_file_upload": { + "description": "启用文件上传到 Misskey", + "type": "bool", + "hint": "启用后,适配器会尝试将消息链中的文件上传到 Misskey。URL 文件会先尝试服务器端上传,异步上传失败时会回退到下载后本地上传。", + }, + "misskey_allow_insecure_downloads": { + "description": "允许不安全下载(禁用 SSL 验证)", + "type": "bool", + "hint": "当远端服务器存在证书问题导致无法正常下载时,自动禁用 SSL 验证作为回退方案。适用于某些图床的证书配置问题。启用有安全风险,仅在必要时使用。", + }, + "misskey_download_timeout": { + "description": "远端下载超时时间(秒)", + "type": "int", + "hint": "下载远程文件时的超时时间(秒),用于异步上传回退到本地上传的场景。", + }, + "misskey_download_chunk_size": { + "description": "流式下载分块大小(字节)", + "type": "int", + "hint": "流式下载和计算 MD5 时使用的每次读取字节数,过小会增加开销,过大会占用内存。", + }, + "misskey_max_download_bytes": { + "description": "最大允许下载字节数(超出则中止)", + "type": "int", + "hint": "如果希望限制下载文件的最大大小以防止 OOM,请填写最大字节数;留空或 null 表示不限制。", + }, + "misskey_upload_concurrency": { + "description": "并发上传限制", + "type": "int", + "hint": "同时进行的文件上传任务上限(整数,默认 3)。", + }, + "misskey_upload_folder": { + "description": "上传到网盘的目标文件夹 ID", + "type": "string", + "hint": "可选:填写 Misskey 网盘中目标文件夹的 ID,上传的文件将放置到该文件夹内。留空则使用账号网盘根目录。", + }, "telegram_command_register": { "description": "Telegram 命令注册", "type": "bool", diff --git a/astrbot/core/platform/sources/misskey/misskey_adapter.py b/astrbot/core/platform/sources/misskey/misskey_adapter.py index 84608b54a..8c7f1b42f 100644 --- a/astrbot/core/platform/sources/misskey/misskey_adapter.py +++ b/astrbot/core/platform/sources/misskey/misskey_adapter.py @@ -1,6 +1,7 @@ import asyncio +import random import json -from typing import Dict, Any, Optional, Awaitable +from typing import Dict, Any, Optional, Awaitable, List from astrbot.api import logger from astrbot.api.event import MessageChain @@ -14,6 +15,13 @@ import astrbot.api.message_components as Comp from .misskey_api import MisskeyAPI +import os + +try: + import magic # type: ignore +except Exception: + magic = None + from .misskey_event import MisskeyPlatformEvent from .misskey_utils import ( serialize_message_chain, @@ -25,9 +33,15 @@ extract_sender_info, create_base_message, process_at_mention, + format_poll, cache_user_info, cache_room_info, ) +from astrbot.core.utils.astrbot_path import get_astrbot_data_path + +# Constants +MAX_FILE_UPLOAD_COUNT = 16 +DEFAULT_UPLOAD_CONCURRENCY = 3 @register_platform_adapter("misskey", "Misskey 平台适配器") @@ -46,6 +60,31 @@ def __init__( ) self.local_only = self.config.get("misskey_local_only", False) self.enable_chat = self.config.get("misskey_enable_chat", True) + self.enable_file_upload = self.config.get("misskey_enable_file_upload", True) + self.upload_folder = self.config.get("misskey_upload_folder") + + # download / security related options (exposed to platform_config) + self.allow_insecure_downloads = bool( + self.config.get("misskey_allow_insecure_downloads", False) + ) + # parse download timeout and chunk size safely + _dt = self.config.get("misskey_download_timeout") + try: + self.download_timeout = int(_dt) if _dt is not None else 15 + except Exception: + self.download_timeout = 15 + + _chunk = self.config.get("misskey_download_chunk_size") + try: + self.download_chunk_size = int(_chunk) if _chunk is not None else 64 * 1024 + except Exception: + self.download_chunk_size = 64 * 1024 + # parse max download bytes safely + _md_bytes = self.config.get("misskey_max_download_bytes") + try: + self.max_download_bytes = int(_md_bytes) if _md_bytes is not None else None + except Exception: + self.max_download_bytes = None self.unique_session = platform_settings["unique_session"] @@ -63,6 +102,11 @@ def meta(self) -> PlatformMetadata: "misskey_default_visibility": "public", "misskey_local_only": False, "misskey_enable_chat": True, + # download / security options + "misskey_allow_insecure_downloads": False, + "misskey_download_timeout": 15, + "misskey_download_chunk_size": 65536, + "misskey_max_download_bytes": None, } default_config.update(self.config) @@ -78,7 +122,14 @@ async def run(self): logger.error("[Misskey] 配置不完整,无法启动") return - self.api = MisskeyAPI(self.instance_url, self.access_token) + self.api = MisskeyAPI( + self.instance_url, + self.access_token, + allow_insecure_downloads=self.allow_insecure_downloads, + download_timeout=self.download_timeout, + chunk_size=self.download_chunk_size, + max_download_bytes=self.max_download_bytes, + ) self._running = True try: @@ -95,6 +146,80 @@ async def run(self): await self._start_websocket_connection() + def _register_event_handlers(self, streaming): + """注册事件处理器""" + streaming.add_message_handler("notification", self._handle_notification) + streaming.add_message_handler("main:notification", self._handle_notification) + + if self.enable_chat: + streaming.add_message_handler("newChatMessage", self._handle_chat_message) + streaming.add_message_handler( + "messaging:newChatMessage", self._handle_chat_message + ) + streaming.add_message_handler("_debug", self._debug_handler) + + async def _send_text_only_message( + self, session_id: str, text: str, session, message_chain + ): + """发送纯文本消息(无文件上传)""" + if not self.api: + return await super().send_by_session(session, message_chain) + + if session_id and is_valid_user_session_id(session_id): + from .misskey_utils import extract_user_id_from_session_id + + user_id = extract_user_id_from_session_id(session_id) + payload: Dict[str, Any] = {"toUserId": user_id, "text": text} + await self.api.send_message(payload) + elif session_id and is_valid_room_session_id(session_id): + from .misskey_utils import extract_room_id_from_session_id + + room_id = extract_room_id_from_session_id(session_id) + payload = {"toRoomId": room_id, "text": text} + await self.api.send_room_message(payload) + + return await super().send_by_session(session, message_chain) + + def _process_poll_data( + self, message: AstrBotMessage, poll: Dict[str, Any], message_parts: List[str] + ): + """处理投票数据,将其添加到消息中""" + try: + if not isinstance(message.raw_message, dict): + message.raw_message = {} + message.raw_message["poll"] = poll + setattr(message, "poll", poll) + except Exception: + pass + + poll_text = format_poll(poll) + if poll_text: + message.message.append(Comp.Plain(poll_text)) + message_parts.append(poll_text) + + def _extract_additional_fields(self, session, message_chain) -> Dict[str, Any]: + """从会话和消息链中提取额外字段""" + fields = {"cw": None, "poll": None, "renote_id": None, "channel_id": None} + + for comp in message_chain.chain: + if hasattr(comp, "cw") and getattr(comp, "cw", None): + fields["cw"] = getattr(comp, "cw") + break + + if hasattr(session, "extra_data") and isinstance( + getattr(session, "extra_data", None), dict + ): + extra_data = getattr(session, "extra_data") + fields.update( + { + "poll": extra_data.get("poll"), + "renote_id": extra_data.get("renote_id"), + "channel_id": extra_data.get("channel_id"), + } + ) + + return fields + async def _start_websocket_connection(self): backoff_delay = 1.0 max_backoff = 300.0 @@ -109,25 +234,20 @@ async def _start_websocket_connection(self): break streaming = self.api.get_streaming_client() - streaming.add_message_handler("notification", self._handle_notification) - if self.enable_chat: - streaming.add_message_handler( - "newChatMessage", self._handle_chat_message - ) - streaming.add_message_handler("_debug", self._debug_handler) + self._register_event_handlers(streaming) if await streaming.connect(): logger.info( f"[Misskey] WebSocket 已连接 (尝试 #{connection_attempts})" ) - connection_attempts = 0 # 重置计数器 + connection_attempts = 0 await streaming.subscribe_channel("main") if self.enable_chat: await streaming.subscribe_channel("messaging") await streaming.subscribe_channel("messagingIndex") logger.info("[Misskey] 聊天频道已订阅") - backoff_delay = 1.0 # 重置延迟 + backoff_delay = 1.0 await streaming.listen() else: logger.error( @@ -140,10 +260,12 @@ async def _start_websocket_connection(self): ) if self._running: + jitter = random.uniform(0, 1.0) + sleep_time = backoff_delay + jitter logger.info( - f"[Misskey] {backoff_delay:.1f}秒后重连 (下次尝试 #{connection_attempts + 1})" + f"[Misskey] {sleep_time:.1f}秒后重连 (下次尝试 #{connection_attempts + 1})" ) - await asyncio.sleep(backoff_delay) + await asyncio.sleep(sleep_time) backoff_delay = min(backoff_delay * backoff_multiplier, max_backoff) async def _handle_notification(self, data: Dict[str, Any]): @@ -164,7 +286,7 @@ async def _handle_notification(self, data: Dict[str, Any]): message_obj=message, platform_meta=self.meta(), session_id=message.session_id, - client=self.api, + client=self, ) self.commit_event(event) except Exception as e: @@ -200,7 +322,7 @@ async def _handle_chat_message(self, data: Dict[str, Any]): message_obj=message, platform_meta=self.meta(), session_id=message.session_id, - client=self.api, + client=self, ) self.commit_event(event) except Exception as e: @@ -239,43 +361,231 @@ async def send_by_session( try: session_id = session.session_id + text, has_at_user = serialize_message_chain(message_chain.chain) if not has_at_user and session_id: user_info = self._user_cache.get(session_id) text = add_at_mention_if_needed(text, user_info, has_at_user) + # 检查是否有文件组件 + has_file_components = any( + isinstance(comp, Comp.Image) + or isinstance(comp, Comp.File) + or hasattr(comp, "convert_to_file_path") + or hasattr(comp, "get_file") + or any( + hasattr(comp, a) for a in ("file", "url", "path", "src", "source") + ) + for comp in message_chain.chain + ) + if not text or not text.strip(): - logger.warning("[Misskey] 消息内容为空,跳过发送") - return await super().send_by_session(session, message_chain) + if not has_file_components: + logger.warning("[Misskey] 消息内容为空且无文件组件,跳过发送") + return await super().send_by_session(session, message_chain) + else: + text = "" if len(text) > self.max_message_length: text = text[: self.max_message_length] + "..." - if session_id and is_valid_user_session_id(session_id): - from .misskey_utils import extract_user_id_from_session_id + file_ids: List[str] = [] + fallback_urls: List[str] = [] + + if not self.enable_file_upload: + return await self._send_text_only_message( + session_id, text, session, message_chain + ) + + MAX_UPLOAD_CONCURRENCY = 10 + upload_concurrency = int( + self.config.get( + "misskey_upload_concurrency", DEFAULT_UPLOAD_CONCURRENCY + ) + ) + upload_concurrency = min(upload_concurrency, MAX_UPLOAD_CONCURRENCY) + sem = asyncio.Semaphore(upload_concurrency) + + async def _upload_comp(comp) -> Optional[object]: + """组件上传函数:处理 URL(下载后上传)或本地文件(直接上传)""" + from .misskey_utils import ( + resolve_component_url_or_path, + upload_local_with_retries, + ) + + local_path = None + try: + async with sem: + if not self.api: + return None + + # 解析组件的 URL 或本地路径 + url_candidate, local_path = await resolve_component_url_or_path( + comp + ) + + if not url_candidate and not local_path: + return None + + preferred_name = getattr(comp, "name", None) or getattr( + comp, "file", None + ) + + # URL 上传:下载后本地上传 + if url_candidate: + result = await self.api.upload_and_find_file( + str(url_candidate), + preferred_name, + folder_id=self.upload_folder, + ) + if isinstance(result, dict) and result.get("id"): + return str(result["id"]) + + # 本地文件上传 + if local_path: + file_id = await upload_local_with_retries( + self.api, + str(local_path), + preferred_name, + self.upload_folder, + ) + if file_id: + return file_id + + # 所有上传都失败,尝试获取 URL 作为回退 + if hasattr(comp, "register_to_file_service"): + try: + url = await comp.register_to_file_service() + if url: + return {"fallback_url": url} + except Exception: + pass + + return None + + finally: + # 清理临时文件 + if local_path and isinstance(local_path, str): + data_temp = os.path.join(get_astrbot_data_path(), "temp") + if local_path.startswith(data_temp) and os.path.exists( + local_path + ): + try: + os.remove(local_path) + logger.debug(f"[Misskey] 已清理临时文件: {local_path}") + except Exception: + pass + + # 收集所有可能包含文件/URL信息的组件:支持异步接口或同步字段 + file_components = [] + for comp in message_chain.chain: + try: + if ( + isinstance(comp, Comp.Image) + or isinstance(comp, Comp.File) + or hasattr(comp, "convert_to_file_path") + or hasattr(comp, "get_file") + or any( + hasattr(comp, a) + for a in ("file", "url", "path", "src", "source") + ) + ): + file_components.append(comp) + except Exception: + # 保守跳过无法访问属性的组件 + continue + + if len(file_components) > MAX_FILE_UPLOAD_COUNT: + logger.warning( + f"[Misskey] 文件数量超过限制 ({len(file_components)} > {MAX_FILE_UPLOAD_COUNT}),只上传前{MAX_FILE_UPLOAD_COUNT}个文件" + ) + file_components = file_components[:MAX_FILE_UPLOAD_COUNT] - user_id = extract_user_id_from_session_id(session_id) - await self.api.send_message(user_id, text) - elif session_id and is_valid_room_session_id(session_id): + upload_tasks = [_upload_comp(comp) for comp in file_components] + + try: + results = await asyncio.gather(*upload_tasks) if upload_tasks else [] + for r in results: + if not r: + continue + if isinstance(r, dict) and r.get("fallback_url"): + url = r.get("fallback_url") + if url: + fallback_urls.append(str(url)) + else: + try: + fid_str = str(r) + if fid_str: + file_ids.append(fid_str) + except Exception: + pass + except Exception: + logger.debug("[Misskey] 并发上传过程中出现异常,继续发送文本") + + if session_id and is_valid_room_session_id(session_id): from .misskey_utils import extract_room_id_from_session_id room_id = extract_room_id_from_session_id(session_id) - await self.api.send_room_message(room_id, text) - else: - visibility, visible_user_ids = resolve_message_visibility( - user_id=session_id, - user_cache=self._user_cache, - self_id=self.client_self_id, - default_visibility=self.default_visibility, + if fallback_urls: + appended = "\n" + "\n".join(fallback_urls) + text = (text or "") + appended + payload: Dict[str, Any] = {"toRoomId": room_id, "text": text} + if file_ids: + payload["fileIds"] = file_ids + await self.api.send_room_message(payload) + elif session_id: + from .misskey_utils import ( + extract_user_id_from_session_id, + is_valid_chat_session_id, ) - await self.api.create_note( - text, - visibility=visibility, - visible_user_ids=visible_user_ids, - local_only=self.local_only, - ) + if is_valid_chat_session_id(session_id): + user_id = extract_user_id_from_session_id(session_id) + if fallback_urls: + appended = "\n" + "\n".join(fallback_urls) + text = (text or "") + appended + payload: Dict[str, Any] = {"toUserId": user_id, "text": text} + if file_ids: + # 聊天消息只支持单个文件,使用 fileId 而不是 fileIds + payload["fileId"] = file_ids[0] + if len(file_ids) > 1: + logger.warning( + f"[Misskey] 聊天消息只支持单个文件,忽略其余 {len(file_ids) - 1} 个文件" + ) + await self.api.send_message(payload) + else: + # 回退到发帖逻辑 + # 去掉 session_id 中的 note% 前缀以匹配 user_cache 的键格式 + user_id_for_cache = ( + session_id.split("%")[1] if "%" in session_id else session_id + ) + visibility, visible_user_ids = resolve_message_visibility( + user_id=user_id_for_cache, + user_cache=self._user_cache, + self_id=self.client_self_id, + default_visibility=self.default_visibility, + ) + logger.debug( + f"[Misskey] 解析可见性: visibility={visibility}, visible_user_ids={visible_user_ids}, session_id={session_id}, user_id_for_cache={user_id_for_cache}" + ) + + fields = self._extract_additional_fields(session, message_chain) + if fallback_urls: + appended = "\n" + "\n".join(fallback_urls) + text = (text or "") + appended + + await self.api.create_note( + text=text, + visibility=visibility, + visible_user_ids=visible_user_ids, + file_ids=file_ids or None, + local_only=self.local_only, + cw=fields["cw"], + poll=fields["poll"], + renote_id=fields["renote_id"], + channel_id=fields["channel_id"], + ) except Exception as e: logger.error(f"[Misskey] 发送消息失败: {e}") @@ -309,6 +619,14 @@ async def convert_message(self, raw_data: Dict[str, Any]) -> AstrBotMessage: file_parts = process_files(message, files) message_parts.extend(file_parts) + poll = raw_data.get("poll") or ( + raw_data.get("note", {}).get("poll") + if isinstance(raw_data.get("note"), dict) + else None + ) + if poll and isinstance(poll, dict): + self._process_poll_data(message, poll, message_parts) + message.message_str = ( " ".join(part for part in message_parts if part.strip()) if message_parts diff --git a/astrbot/core/platform/sources/misskey/misskey_api.py b/astrbot/core/platform/sources/misskey/misskey_api.py index dc4adcdd0..0c3334ef6 100644 --- a/astrbot/core/platform/sources/misskey/misskey_api.py +++ b/astrbot/core/platform/sources/misskey/misskey_api.py @@ -1,4 +1,6 @@ import json +import random +import asyncio from typing import Any, Optional, Dict, List, Callable, Awaitable import uuid @@ -11,6 +13,7 @@ ) from e from astrbot.api import logger +from .misskey_utils import FileIDExtractor # Constants API_MAX_RETRIES = 3 @@ -55,6 +58,7 @@ def __init__(self, instance_url: str, access_token: str): self.is_connected = False self.message_handlers: Dict[str, Callable] = {} self.channels: Dict[str, str] = {} + self.desired_channels: Dict[str, Optional[Dict]] = {} self._running = False self._last_pong = None @@ -72,6 +76,18 @@ async def connect(self) -> bool: self._running = True logger.info("[Misskey WebSocket] 已连接") + if self.desired_channels: + try: + desired = list(self.desired_channels.items()) + for channel_type, params in desired: + try: + await self.subscribe_channel(channel_type, params) + except Exception as e: + logger.warning( + f"[Misskey WebSocket] 重新订阅 {channel_type} 失败: {e}" + ) + except Exception: + pass return True except Exception as e: @@ -112,9 +128,12 @@ async def unsubscribe_channel(self, channel_id: str): return message = {"type": "disconnect", "body": {"id": channel_id}} - await self.websocket.send(json.dumps(message)) - del self.channels[channel_id] + channel_type = self.channels.get(channel_id) + if channel_id in self.channels: + del self.channels[channel_id] + if channel_type and channel_type not in self.channels.values(): + self.desired_channels.pop(channel_type, None) def add_message_handler( self, event_type: str, handler: Callable[[Dict], Awaitable[None]] @@ -141,24 +160,70 @@ async def listen(self): except websockets.exceptions.ConnectionClosedError as e: logger.warning(f"[Misskey WebSocket] 连接意外关闭: {e}") self.is_connected = False + try: + await self.disconnect() + except Exception: + pass except websockets.exceptions.ConnectionClosed as e: logger.warning( f"[Misskey WebSocket] 连接已关闭 (代码: {e.code}, 原因: {e.reason})" ) self.is_connected = False + try: + await self.disconnect() + except Exception: + pass except websockets.exceptions.InvalidHandshake as e: logger.error(f"[Misskey WebSocket] 握手失败: {e}") self.is_connected = False + try: + await self.disconnect() + except Exception: + pass except Exception as e: logger.error(f"[Misskey WebSocket] 监听消息失败: {e}") self.is_connected = False + try: + await self.disconnect() + except Exception: + pass async def _handle_message(self, data: Dict[str, Any]): message_type = data.get("type") body = data.get("body", {}) + def _build_channel_summary(message_type: Optional[str], body: Any) -> str: + try: + if not isinstance(body, dict): + return f"[Misskey WebSocket] 收到消息类型: {message_type}" + + inner = body.get("body") if isinstance(body.get("body"), dict) else body + note = ( + inner.get("note") + if isinstance(inner, dict) and isinstance(inner.get("note"), dict) + else None + ) + + text = note.get("text") if note else None + note_id = note.get("id") if note else None + files = note.get("files") or [] if note else [] + has_files = bool(files) + is_hidden = bool(note.get("isHidden")) if note else False + user = note.get("user", {}) if note else None + + return ( + f"[Misskey WebSocket] 收到消息类型: {message_type} | " + f"note_id={note_id} | user={user.get('username') if user else None} | " + f"text={text[:80] if text else '[no-text]'} | files={has_files} | hidden={is_hidden}" + ) + except Exception: + return f"[Misskey WebSocket] 收到消息类型: {message_type}" + + channel_summary = _build_channel_summary(message_type, body) + logger.info(channel_summary) + logger.debug( - f"[Misskey WebSocket] 收到消息类型: {message_type}\n数据: {json.dumps(data, indent=2, ensure_ascii=False)}" + f"[Misskey WebSocket] 收到完整消息: {json.dumps(data, indent=2, ensure_ascii=False)}" ) if message_type == "channel": @@ -202,16 +267,60 @@ async def _handle_message(self, data: Dict[str, Any]): await self.message_handlers["_debug"](data) -def retry_async(max_retries: int = 3, retryable_exceptions: tuple = ()): +def retry_async( + max_retries: int = 3, + retryable_exceptions: tuple = (APIConnectionError, APIRateLimitError), + backoff_base: float = 1.0, + max_backoff: float = 30.0, +): + """ + 智能异步重试装饰器 + + Args: + max_retries: 最大重试次数 + retryable_exceptions: 可重试的异常类型 + backoff_base: 退避基数 + max_backoff: 最大退避时间 + """ + def decorator(func): async def wrapper(*args, **kwargs): last_exc = None - for _ in range(max_retries): + func_name = getattr(func, "__name__", "unknown") + + for attempt in range(1, max_retries + 1): try: return await func(*args, **kwargs) except retryable_exceptions as e: last_exc = e + if attempt == max_retries: + logger.error( + f"[Misskey API] {func_name} 重试 {max_retries} 次后仍失败: {e}" + ) + break + + # 智能退避策略 + if isinstance(e, APIRateLimitError): + # 频率限制用更长的退避时间 + backoff = min(backoff_base * (3**attempt), max_backoff) + else: + # 其他错误用指数退避 + backoff = min(backoff_base * (2**attempt), max_backoff) + + jitter = random.uniform(0.1, 0.5) # 随机抖动 + sleep_time = backoff + jitter + + logger.warning( + f"[Misskey API] {func_name} 第 {attempt} 次重试失败: {e}," + f"{sleep_time:.1f}s后重试" + ) + await asyncio.sleep(sleep_time) continue + except Exception as e: + # 非可重试异常直接抛出 + logger.error(f"[Misskey API] {func_name} 遇到不可重试异常: {e}") + raise + if last_exc: raise last_exc @@ -221,11 +330,27 @@ async def wrapper(*args, **kwargs): class MisskeyAPI: - def __init__(self, instance_url: str, access_token: str): + def __init__( + self, + instance_url: str, + access_token: str, + *, + allow_insecure_downloads: bool = False, + download_timeout: int = 15, + chunk_size: int = 64 * 1024, + max_download_bytes: Optional[int] = None, + ): self.instance_url = instance_url.rstrip("/") self.access_token = access_token self._session: Optional[aiohttp.ClientSession] = None self.streaming: Optional[StreamingClient] = None + # download options + self.allow_insecure_downloads = allow_insecure_downloads + self.download_timeout = download_timeout + self.chunk_size = chunk_size + self.max_download_bytes = ( + int(max_download_bytes) if max_download_bytes is not None else None + ) async def __aenter__(self): return self @@ -258,16 +383,37 @@ def session(self) -> aiohttp.ClientSession: def _handle_response_status(self, status: int, endpoint: str): """处理 HTTP 响应状态码""" if status == 400: - logger.error(f"API 请求错误: {endpoint} (状态码: {status})") + logger.error(f"[Misskey API] 请求参数错误: {endpoint} (HTTP {status})") raise APIError(f"Bad request for {endpoint}") - elif status in (401, 403): - logger.error(f"API 认证失败: {endpoint} (状态码: {status})") - raise AuthenticationError(f"Authentication failed for {endpoint}") + elif status == 401: + logger.error(f"[Misskey API] 未授权访问: {endpoint} (HTTP {status})") + raise AuthenticationError(f"Unauthorized access for {endpoint}") + elif status == 403: + logger.error(f"[Misskey API] 访问被禁止: {endpoint} (HTTP {status})") + raise AuthenticationError(f"Forbidden access for {endpoint}") + elif status == 404: + logger.error(f"[Misskey API] 资源不存在: {endpoint} (HTTP {status})") + raise APIError(f"Resource not found for {endpoint}") + elif status == 413: + logger.error(f"[Misskey API] 请求体过大: {endpoint} (HTTP {status})") + raise APIError(f"Request entity too large for {endpoint}") elif status == 429: - logger.warning(f"API 频率限制: {endpoint} (状态码: {status})") + logger.warning(f"[Misskey API] 请求频率限制: {endpoint} (HTTP {status})") raise APIRateLimitError(f"Rate limit exceeded for {endpoint}") + elif status == 500: + logger.error(f"[Misskey API] 服务器内部错误: {endpoint} (HTTP {status})") + raise APIConnectionError(f"Internal server error for {endpoint}") + elif status == 502: + logger.error(f"[Misskey API] 网关错误: {endpoint} (HTTP {status})") + raise APIConnectionError(f"Bad gateway for {endpoint}") + elif status == 503: + logger.error(f"[Misskey API] 服务不可用: {endpoint} (HTTP {status})") + raise APIConnectionError(f"Service unavailable for {endpoint}") + elif status == 504: + logger.error(f"[Misskey API] 网关超时: {endpoint} (HTTP {status})") + raise APIConnectionError(f"Gateway timeout for {endpoint}") else: - logger.error(f"API 请求失败: {endpoint} (状态码: {status})") + logger.error(f"[Misskey API] 未知错误: {endpoint} (HTTP {status})") raise APIConnectionError(f"HTTP {status} for {endpoint}") async def _process_response( @@ -286,21 +432,25 @@ async def _process_response( else [] ) if notifications_data: - logger.debug(f"获取到 {len(notifications_data)} 条新通知") + logger.debug( + f"[Misskey API] 获取到 {len(notifications_data)} 条新通知" + ) else: - logger.debug(f"API 请求成功: {endpoint}") + logger.debug(f"[Misskey API] 请求成功: {endpoint}") return result except json.JSONDecodeError as e: - logger.error(f"响应不是有效的 JSON 格式: {e}") + logger.error(f"[Misskey API] 响应格式错误: {e}") raise APIConnectionError("Invalid JSON response") from e else: try: error_text = await response.text() logger.error( - f"API 请求失败: {endpoint} - 状态码: {response.status}, 响应: {error_text}" + f"[Misskey API] 请求失败: {endpoint} - HTTP {response.status}, 响应: {error_text}" ) except Exception: - logger.error(f"API 请求失败: {endpoint} - 状态码: {response.status}") + logger.error( + f"[Misskey API] 请求失败: {endpoint} - HTTP {response.status}" + ) self._handle_response_status(response.status, endpoint) raise APIConnectionError(f"Request failed for {endpoint}") @@ -321,53 +471,307 @@ async def _make_request( async with self.session.post(url, json=payload) as response: return await self._process_response(response, endpoint) except aiohttp.ClientError as e: - logger.error(f"HTTP 请求错误: {e}") + logger.error(f"[Misskey API] HTTP 请求错误: {e}") raise APIConnectionError(f"HTTP request failed: {e}") from e async def create_note( self, - text: str, + text: Optional[str] = None, visibility: str = "public", reply_id: Optional[str] = None, visible_user_ids: Optional[List[str]] = None, + file_ids: Optional[List[str]] = None, local_only: bool = False, + cw: Optional[str] = None, + poll: Optional[Dict[str, Any]] = None, + renote_id: Optional[str] = None, + channel_id: Optional[str] = None, + reaction_acceptance: Optional[str] = None, + no_extract_mentions: Optional[bool] = None, + no_extract_hashtags: Optional[bool] = None, + no_extract_emojis: Optional[bool] = None, + media_ids: Optional[List[str]] = None, ) -> Dict[str, Any]: - """创建新贴文""" - data: Dict[str, Any] = { - "text": text, - "visibility": visibility, - "localOnly": local_only, - } + """Create a note (wrapper for notes/create). All additional fields are optional and passed through to the API.""" + data: Dict[str, Any] = {} + + if text is not None: + data["text"] = text + + data["visibility"] = visibility + data["localOnly"] = local_only + if reply_id: data["replyId"] = reply_id + if visible_user_ids and visibility == "specified": data["visibleUserIds"] = visible_user_ids + if file_ids: + data["fileIds"] = file_ids + if media_ids: + data["mediaIds"] = media_ids + + if cw is not None: + data["cw"] = cw + if poll is not None: + data["poll"] = poll + if renote_id is not None: + data["renoteId"] = renote_id + if channel_id is not None: + data["channelId"] = channel_id + if reaction_acceptance is not None: + data["reactionAcceptance"] = reaction_acceptance + if no_extract_mentions is not None: + data["noExtractMentions"] = bool(no_extract_mentions) + if no_extract_hashtags is not None: + data["noExtractHashtags"] = bool(no_extract_hashtags) + if no_extract_emojis is not None: + data["noExtractEmojis"] = bool(no_extract_emojis) + result = await self._make_request("notes/create", data) - note_id = result.get("createdNote", {}).get("id", "unknown") - logger.debug(f"发帖成功,note_id: {note_id}") + note_id = ( + result.get("createdNote", {}).get("id", "unknown") + if isinstance(result, dict) + else "unknown" + ) + logger.debug(f"[Misskey API] 发帖成功: {note_id}") return result + async def upload_file( + self, + file_path: str, + name: Optional[str] = None, + folder_id: Optional[str] = None, + ) -> Dict[str, Any]: + """Upload a file to Misskey drive/files/create and return a dict containing id and raw result.""" + if not file_path: + raise APIError("No file path provided for upload") + + url = f"{self.instance_url}/api/drive/files/create" + form = aiohttp.FormData() + form.add_field("i", self.access_token) + + try: + filename = name or file_path.split("/")[-1] + if folder_id: + form.add_field("folderId", str(folder_id)) + + try: + f = open(file_path, "rb") + except FileNotFoundError as e: + logger.error(f"[Misskey API] 本地文件不存在: {file_path}") + raise APIError(f"File not found: {file_path}") from e + + try: + form.add_field("file", f, filename=filename) + async with self.session.post(url, data=form) as resp: + result = await self._process_response(resp, "drive/files/create") + file_id = FileIDExtractor.extract_file_id(result) + logger.debug( + f"[Misskey API] 本地文件上传成功: {filename} -> {file_id}" + ) + return {"id": file_id, "raw": result} + finally: + f.close() + except aiohttp.ClientError as e: + logger.error(f"[Misskey API] 文件上传网络错误: {e}") + raise APIConnectionError(f"Upload failed: {e}") from e + + async def find_files_by_hash(self, md5_hash: str) -> List[Dict[str, Any]]: + """Find files by MD5 hash""" + if not md5_hash: + raise APIError("No MD5 hash provided for find-by-hash") + + data = {"md5": md5_hash} + + try: + logger.debug(f"[Misskey API] find-by-hash 请求: md5={md5_hash}") + result = await self._make_request("drive/files/find-by-hash", data) + logger.debug( + f"[Misskey API] find-by-hash 响应: 找到 {len(result) if isinstance(result, list) else 0} 个文件" + ) + return result if isinstance(result, list) else [] + except Exception as e: + logger.error(f"[Misskey API] 根据哈希查找文件失败: {e}") + raise + + async def find_files_by_name( + self, name: str, folder_id: Optional[str] = None + ) -> List[Dict[str, Any]]: + """Find files by name""" + if not name: + raise APIError("No name provided for find") + + data: Dict[str, Any] = {"name": name} + if folder_id: + data["folderId"] = folder_id + + try: + logger.debug(f"[Misskey API] find 请求: name={name}, folder_id={folder_id}") + result = await self._make_request("drive/files/find", data) + logger.debug( + f"[Misskey API] find 响应: 找到 {len(result) if isinstance(result, list) else 0} 个文件" + ) + return result if isinstance(result, list) else [] + except Exception as e: + logger.error(f"[Misskey API] 根据名称查找文件失败: {e}") + raise + + async def find_files( + self, + limit: int = 10, + folder_id: Optional[str] = None, + type: Optional[str] = None, + ) -> List[Dict[str, Any]]: + """List files with optional filters""" + data: Dict[str, Any] = {"limit": limit} + if folder_id is not None: + data["folderId"] = folder_id + if type is not None: + data["type"] = type + + try: + logger.debug( + f"[Misskey API] 列表文件请求: limit={limit}, folder_id={folder_id}, type={type}" + ) + result = await self._make_request("drive/files", data) + logger.debug( + f"[Misskey API] 列表文件响应: 找到 {len(result) if isinstance(result, list) else 0} 个文件" + ) + return result if isinstance(result, list) else [] + except Exception as e: + logger.error(f"[Misskey API] 列表文件失败: {e}") + raise + + async def _download_with_existing_session( + self, url: str, ssl_verify: bool = True + ) -> Optional[bytes]: + """使用现有会话下载文件""" + if not (hasattr(self, "session") and self.session): + raise APIConnectionError("No existing session available") + + async with self.session.get( + url, timeout=aiohttp.ClientTimeout(total=15), ssl=ssl_verify + ) as response: + if response.status == 200: + return await response.read() + return None + + async def _download_with_temp_session( + self, url: str, ssl_verify: bool = True + ) -> Optional[bytes]: + """使用临时会话下载文件""" + connector = aiohttp.TCPConnector(ssl=ssl_verify) + async with aiohttp.ClientSession(connector=connector) as temp_session: + async with temp_session.get( + url, timeout=aiohttp.ClientTimeout(total=15) + ) as response: + if response.status == 200: + return await response.read() + return None + + async def upload_and_find_file( + self, + url: str, + name: Optional[str] = None, + folder_id: Optional[str] = None, + max_wait_time: float = 30.0, + check_interval: float = 2.0, + ) -> Optional[Dict[str, Any]]: + """ + 简化的文件上传:尝试 URL 上传,失败则下载后本地上传 + + Args: + url: 文件URL + name: 文件名(可选) + folder_id: 文件夹ID(可选) + max_wait_time: 保留参数(未使用) + check_interval: 保留参数(未使用) + + Returns: + 包含文件ID和元信息的字典,失败时返回None + """ + if not url: + raise APIError("URL不能为空") + + # 通过本地上传获取即时文件 ID(下载文件 → 上传 → 返回 ID) + try: + import tempfile + import os + + # SSL 验证下载,失败则重试不验证 SSL + tmp_bytes = None + try: + tmp_bytes = await self._download_with_existing_session( + url, ssl_verify=True + ) or await self._download_with_temp_session(url, ssl_verify=True) + except Exception as ssl_error: + logger.debug( + f"[Misskey API] SSL 验证下载失败: {ssl_error},重试不验证 SSL" + ) + try: + tmp_bytes = await self._download_with_existing_session( + url, ssl_verify=False + ) or await self._download_with_temp_session(url, ssl_verify=False) + except Exception: + pass + + if tmp_bytes: + with tempfile.NamedTemporaryFile(delete=False) as tmpf: + tmpf.write(tmp_bytes) + tmp_path = tmpf.name + + try: + result = await self.upload_file(tmp_path, name, folder_id) + logger.debug(f"[Misskey API] 本地上传成功: {result.get('id')}") + return result + finally: + try: + os.unlink(tmp_path) + except Exception: + pass + except Exception as e: + logger.error(f"[Misskey API] 本地上传失败: {e}") + + return None + async def get_current_user(self) -> Dict[str, Any]: """获取当前用户信息""" return await self._make_request("i", {}) - async def send_message(self, user_id: str, text: str) -> Dict[str, Any]: - """发送聊天消息""" - result = await self._make_request( - "chat/messages/create-to-user", {"toUserId": user_id, "text": text} - ) + async def send_message( + self, user_id_or_payload: Any, text: Optional[str] = None + ) -> Dict[str, Any]: + """发送聊天消息。 + + Accepts either (user_id: str, text: str) or a single dict payload prepared by caller. + """ + if isinstance(user_id_or_payload, dict): + data = user_id_or_payload + else: + data = {"toUserId": user_id_or_payload, "text": text} + + result = await self._make_request("chat/messages/create-to-user", data) message_id = result.get("id", "unknown") - logger.debug(f"聊天发送成功,message_id: {message_id}") + logger.debug(f"[Misskey API] 聊天消息发送成功: {message_id}") return result - async def send_room_message(self, room_id: str, text: str) -> Dict[str, Any]: - """发送房间消息""" - result = await self._make_request( - "chat/messages/create-to-room", {"toRoomId": room_id, "text": text} - ) + async def send_room_message( + self, room_id_or_payload: Any, text: Optional[str] = None + ) -> Dict[str, Any]: + """发送房间消息。 + + Accepts either (room_id: str, text: str) or a single dict payload. + """ + if isinstance(room_id_or_payload, dict): + data = room_id_or_payload + else: + data = {"toRoomId": room_id_or_payload, "text": text} + + result = await self._make_request("chat/messages/create-to-room", data) message_id = result.get("id", "unknown") - logger.debug(f"房间消息发送成功,message_id: {message_id}") + logger.debug(f"[Misskey API] 房间消息发送成功: {message_id}") return result async def get_messages( @@ -381,9 +785,8 @@ async def get_messages( result = await self._make_request("chat/messages/user-timeline", data) if isinstance(result, list): return result - else: - logger.warning(f"获取聊天消息响应格式异常: {type(result)}") - return [] + logger.warning(f"[Misskey API] 聊天消息响应格式异常: {type(result)}") + return [] async def get_mentions( self, limit: int = 10, since_id: Optional[str] = None @@ -400,5 +803,142 @@ async def get_mentions( elif isinstance(result, dict) and "notifications" in result: return result["notifications"] else: - logger.warning(f"获取提及通知响应格式异常: {type(result)}") + logger.warning(f"[Misskey API] 提及通知响应格式异常: {type(result)}") return [] + + async def send_message_with_media( + self, + message_type: str, + target_id: str, + text: Optional[str] = None, + media_urls: Optional[List[str]] = None, + local_files: Optional[List[str]] = None, + **kwargs, + ) -> Dict[str, Any]: + """ + 通用消息发送函数:统一处理文本+媒体发送 + + Args: + message_type: 消息类型 ('chat', 'room', 'note') + target_id: 目标ID (用户ID/房间ID/频道ID等) + text: 文本内容 + media_urls: 媒体文件URL列表 + local_files: 本地文件路径列表 + **kwargs: 其他参数(如visibility等) + + Returns: + 发送结果字典 + + Raises: + APIError: 参数错误或发送失败 + """ + if not text and not media_urls and not local_files: + raise APIError("消息内容不能为空:需要文本或媒体文件") + + file_ids = [] + + # 处理远程媒体文件 + if media_urls: + file_ids.extend(await self._process_media_urls(media_urls)) + + # 处理本地文件 + if local_files: + file_ids.extend(await self._process_local_files(local_files)) + + # 根据消息类型发送 + return await self._dispatch_message( + message_type, target_id, text, file_ids, **kwargs + ) + + async def _process_media_urls(self, urls: List[str]) -> List[str]: + """处理远程媒体文件URL列表,返回文件ID列表""" + file_ids = [] + for url in urls: + try: + result = await self.upload_and_find_file(url) + if result and result.get("id"): + file_ids.append(result["id"]) + logger.debug(f"[Misskey API] URL媒体上传成功: {result['id']}") + else: + logger.error(f"[Misskey API] URL媒体上传失败: {url}") + except Exception as e: + logger.error(f"[Misskey API] URL媒体处理失败 {url}: {e}") + # 继续处理其他文件,不中断整个流程 + continue + return file_ids + + async def _process_local_files(self, file_paths: List[str]) -> List[str]: + """处理本地文件路径列表,返回文件ID列表""" + file_ids = [] + for file_path in file_paths: + try: + result = await self.upload_file(file_path) + if result and result.get("id"): + file_ids.append(result["id"]) + logger.debug(f"[Misskey API] 本地文件上传成功: {result['id']}") + else: + logger.error(f"[Misskey API] 本地文件上传失败: {file_path}") + except Exception as e: + logger.error(f"[Misskey API] 本地文件处理失败 {file_path}: {e}") + continue + return file_ids + + async def _dispatch_message( + self, + message_type: str, + target_id: str, + text: Optional[str], + file_ids: List[str], + **kwargs, + ) -> Dict[str, Any]: + """根据消息类型分发到对应的发送方法""" + if message_type == "chat": + # 聊天消息使用 fileId (单数) + payload = {"toUserId": target_id} + if text: + payload["text"] = text + if file_ids: + if len(file_ids) == 1: + payload["fileId"] = file_ids[0] + else: + # 多文件时逐个发送 + results = [] + for file_id in file_ids: + single_payload = payload.copy() + single_payload["fileId"] = file_id + result = await self.send_message(single_payload) + results.append(result) + return {"multiple": True, "results": results} + return await self.send_message(payload) + + elif message_type == "room": + # 房间消息使用 fileId (单数) + payload = {"toRoomId": target_id} + if text: + payload["text"] = text + if file_ids: + if len(file_ids) == 1: + payload["fileId"] = file_ids[0] + else: + # 多文件时逐个发送 + results = [] + for file_id in file_ids: + single_payload = payload.copy() + single_payload["fileId"] = file_id + result = await self.send_room_message(single_payload) + results.append(result) + return {"multiple": True, "results": results} + return await self.send_room_message(payload) + + elif message_type == "note": + # 发帖使用 fileIds (复数) + note_kwargs = { + "text": text, + "file_ids": file_ids or None, + } + # 合并其他参数 + note_kwargs.update(kwargs) + return await self.create_note(**note_kwargs) + + else: + raise APIError(f"不支持的消息类型: {message_type}") diff --git a/astrbot/core/platform/sources/misskey/misskey_event.py b/astrbot/core/platform/sources/misskey/misskey_event.py index 391d10b52..cd737f78e 100644 --- a/astrbot/core/platform/sources/misskey/misskey_event.py +++ b/astrbot/core/platform/sources/misskey/misskey_event.py @@ -40,48 +40,83 @@ def _is_system_command(self, message_str: str) -> bool: return any(message_trimmed.startswith(prefix) for prefix in system_prefixes) async def send(self, message: MessageChain): - content, has_at = serialize_message_chain(message.chain) - - if not content: - logger.debug("[MisskeyEvent] 内容为空,跳过发送") - return - + """发送消息,使用适配器的完整上传和发送逻辑""" try: - original_message_id = getattr(self.message_obj, "message_id", None) - raw_message = getattr(self.message_obj, "raw_message", {}) - - if raw_message and not has_at: - user_data = raw_message.get("user", {}) - user_info = { - "username": user_data.get("username", ""), - "nickname": user_data.get("name", user_data.get("username", "")), - } - content = add_at_mention_if_needed(content, user_info, has_at) - - # 根据会话类型选择发送方式 - if hasattr(self.client, "send_message") and is_valid_user_session_id( - self.session_id - ): - user_id = extract_user_id_from_session_id(self.session_id) - await self.client.send_message(user_id, content) - elif hasattr(self.client, "send_room_message") and is_valid_room_session_id( - self.session_id - ): - room_id = extract_room_id_from_session_id(self.session_id) - await self.client.send_room_message(room_id, content) - elif original_message_id and hasattr(self.client, "create_note"): - visibility, visible_user_ids = resolve_visibility_from_raw_message( - raw_message - ) - await self.client.create_note( - content, - reply_id=original_message_id, - visibility=visibility, - visible_user_ids=visible_user_ids, - ) - elif hasattr(self.client, "create_note"): - logger.debug("[MisskeyEvent] 创建新帖子") - await self.client.create_note(content) + logger.debug( + f"[MisskeyEvent] send 方法被调用,消息链包含 {len(message.chain)} 个组件" + ) + + # 使用适配器的 send_by_session 方法,它包含文件上传逻辑 + from astrbot.core.platform.message_session import MessageSession + from astrbot.core.platform.message_type import MessageType + + # 根据session_id类型确定消息类型 + if is_valid_user_session_id(self.session_id): + message_type = MessageType.FRIEND_MESSAGE + elif is_valid_room_session_id(self.session_id): + message_type = MessageType.GROUP_MESSAGE + else: + message_type = MessageType.FRIEND_MESSAGE # 默认 + + session = MessageSession( + platform_name=self.platform_meta.name, + message_type=message_type, + session_id=self.session_id, + ) + + logger.debug( + f"[MisskeyEvent] 检查适配器方法: hasattr(self.client, 'send_by_session') = {hasattr(self.client, 'send_by_session')}" + ) + + # 调用适配器的 send_by_session 方法 + if hasattr(self.client, "send_by_session"): + logger.debug("[MisskeyEvent] 调用适配器的 send_by_session 方法") + await self.client.send_by_session(session, message) + else: + # 回退到原来的简化发送逻辑 + content, has_at = serialize_message_chain(message.chain) + + if not content: + logger.debug("[MisskeyEvent] 内容为空,跳过发送") + return + + original_message_id = getattr(self.message_obj, "message_id", None) + raw_message = getattr(self.message_obj, "raw_message", {}) + + if raw_message and not has_at: + user_data = raw_message.get("user", {}) + user_info = { + "username": user_data.get("username", ""), + "nickname": user_data.get( + "name", user_data.get("username", "") + ), + } + content = add_at_mention_if_needed(content, user_info, has_at) + + # 根据会话类型选择发送方式 + if hasattr(self.client, "send_message") and is_valid_user_session_id( + self.session_id + ): + user_id = extract_user_id_from_session_id(self.session_id) + await self.client.send_message(user_id, content) + elif hasattr( + self.client, "send_room_message" + ) and is_valid_room_session_id(self.session_id): + room_id = extract_room_id_from_session_id(self.session_id) + await self.client.send_room_message(room_id, content) + elif original_message_id and hasattr(self.client, "create_note"): + visibility, visible_user_ids = resolve_visibility_from_raw_message( + raw_message + ) + await self.client.create_note( + content, + reply_id=original_message_id, + visibility=visibility, + visible_user_ids=visible_user_ids, + ) + elif hasattr(self.client, "create_note"): + logger.debug("[MisskeyEvent] 创建新帖子") + await self.client.create_note(content) await super().send(message) diff --git a/astrbot/core/platform/sources/misskey/misskey_utils.py b/astrbot/core/platform/sources/misskey/misskey_utils.py index 9a96b453f..d10b29431 100644 --- a/astrbot/core/platform/sources/misskey/misskey_utils.py +++ b/astrbot/core/platform/sources/misskey/misskey_utils.py @@ -5,6 +5,68 @@ from astrbot.api.platform import AstrBotMessage, MessageMember, MessageType +class FileIDExtractor: + """从 API 响应中提取文件 ID 的帮助类(无状态)。""" + + @staticmethod + def extract_file_id(result: Any) -> Optional[str]: + if not isinstance(result, dict): + return None + + id_paths = [ + lambda r: r.get("createdFile", {}).get("id"), + lambda r: r.get("file", {}).get("id"), + lambda r: r.get("id"), + ] + + for p in id_paths: + try: + if fid := p(result): + return fid + except Exception: + continue + + return None + + +class MessagePayloadBuilder: + """构建不同类型消息负载的帮助类(无状态)。""" + + @staticmethod + def build_chat_payload( + user_id: str, text: Optional[str], file_id: Optional[str] = None + ) -> Dict[str, Any]: + payload = {"toUserId": user_id} + if text: + payload["text"] = text + if file_id: + payload["fileId"] = file_id + return payload + + @staticmethod + def build_room_payload( + room_id: str, text: Optional[str], file_id: Optional[str] = None + ) -> Dict[str, Any]: + payload = {"toRoomId": room_id} + if text: + payload["text"] = text + if file_id: + payload["fileId"] = file_id + return payload + + @staticmethod + def build_note_payload( + text: Optional[str], file_ids: Optional[List[str]] = None, **kwargs + ) -> Dict[str, Any]: + payload: Dict[str, Any] = {} + if text: + payload["text"] = text + if file_ids: + payload["fileIds"] = file_ids + payload |= kwargs + return payload + + def serialize_message_chain(chain: List[Any]) -> Tuple[str, bool]: """将消息链序列化为文本字符串""" text_parts = [] @@ -15,8 +77,11 @@ def process_component(component): if isinstance(component, Comp.Plain): return component.text elif isinstance(component, Comp.File): - file_name = getattr(component, "name", "文件") - return f"[文件: {file_name}]" + # 为文件组件返回占位符,但适配器仍会处理原组件 + return "[文件]" + elif isinstance(component, Comp.Image): + # 为图片组件返回占位符,但适配器仍会处理原组件 + return "[图片]" elif isinstance(component, Comp.At): has_at = True return f"@{component.qq}" @@ -43,15 +108,22 @@ def process_component(component): def resolve_message_visibility( - user_id: Optional[str], - user_cache: Dict[str, Any], - self_id: Optional[str], + user_id: Optional[str] = None, + user_cache: Optional[Dict[str, Any]] = None, + self_id: Optional[str] = None, + raw_message: Optional[Dict[str, Any]] = None, default_visibility: str = "public", ) -> Tuple[str, Optional[List[str]]]: - """解析 Misskey 消息的可见性设置""" + """解析 Misskey 消息的可见性设置 + + 可以从 user_cache 或 raw_message 中解析,支持两种调用方式: + 1. 基于 user_cache: resolve_message_visibility(user_id, user_cache, self_id) + 2. 基于 raw_message: resolve_message_visibility(raw_message=raw_message, self_id=self_id) + """ visibility = default_visibility visible_user_ids = None + # 优先从 user_cache 解析 if user_id and user_cache: user_info = user_cache.get(user_id) if user_info: @@ -66,38 +138,36 @@ def resolve_message_visibility( visible_user_ids = [uid for uid in visible_user_ids if uid] else: visibility = original_visibility + return visibility, visible_user_ids + + # 回退到从 raw_message 解析 + if raw_message: + original_visibility = raw_message.get("visibility", default_visibility) + if original_visibility == "specified": + visibility = "specified" + original_visible_users = raw_message.get("visibleUserIds", []) + sender_id = raw_message.get("userId", "") + + users_to_include = [] + if sender_id: + users_to_include.append(sender_id) + if self_id: + users_to_include.append(self_id) + + visible_user_ids = list(set(original_visible_users + users_to_include)) + visible_user_ids = [uid for uid in visible_user_ids if uid] + else: + visibility = original_visibility return visibility, visible_user_ids +# 保留旧函数名作为向后兼容的别名 def resolve_visibility_from_raw_message( raw_message: Dict[str, Any], self_id: Optional[str] = None ) -> Tuple[str, Optional[List[str]]]: - """从原始消息数据中解析可见性设置""" - visibility = "public" - visible_user_ids = None - - if not raw_message: - return visibility, visible_user_ids - - original_visibility = raw_message.get("visibility", "public") - if original_visibility == "specified": - visibility = "specified" - original_visible_users = raw_message.get("visibleUserIds", []) - sender_id = raw_message.get("userId", "") - - users_to_include = [] - if sender_id: - users_to_include.append(sender_id) - if self_id: - users_to_include.append(self_id) - - visible_user_ids = list(set(original_visible_users + users_to_include)) - visible_user_ids = [uid for uid in visible_user_ids if uid] - else: - visibility = original_visibility - - return visibility, visible_user_ids + """从原始消息数据中解析可见性设置(已弃用,使用 resolve_message_visibility 替代)""" + return resolve_message_visibility(raw_message=raw_message, self_id=self_id) def is_valid_user_session_id(session_id: Union[str, Any]) -> bool: @@ -128,6 +198,20 @@ def is_valid_room_session_id(session_id: Union[str, Any]) -> bool: ) +def is_valid_chat_session_id(session_id: Union[str, Any]) -> bool: + """检查 session_id 是否是有效的聊天 session_id (仅限chat%前缀)""" + if not isinstance(session_id, str) or "%" not in session_id: + return False + + parts = session_id.split("%") + return ( + len(parts) == 2 + and parts[0] == "chat" + and bool(parts[1]) + and parts[1] != "unknown" + ) + + def extract_user_id_from_session_id(session_id: str) -> str: """从 session_id 中提取用户 ID""" if "%" in session_id: @@ -197,6 +281,22 @@ def process_files( return file_parts +def format_poll(poll: Dict[str, Any]) -> str: + """将 Misskey 的 poll 对象格式化为可读字符串。""" + if not poll or not isinstance(poll, dict): + return "" + multiple = poll.get("multiple", False) + choices = poll.get("choices", []) + text_choices = [ + f"({idx}) {c.get('text', '')} [{c.get('votes', 0)}票]" + for idx, c in enumerate(choices, start=1) + ] + parts = ["[投票]", ("允许多选" if multiple else "单选")] + ( + ["选项: " + ", ".join(text_choices)] if text_choices else [] + ) + return " ".join(parts) + + def extract_sender_info( raw_data: Dict[str, Any], is_chat: bool = False ) -> Dict[str, Any]: @@ -248,7 +348,7 @@ def create_base_message( else: session_prefix = "note" session_id = f"{session_prefix}%{sender_info['sender_id']}" - message.type = MessageType.FRIEND_MESSAGE + message.type = MessageType.OTHER_MESSAGE message.session_id = ( session_id if sender_info["sender_id"] else f"{session_prefix}%unknown" @@ -325,3 +425,106 @@ def cache_room_info( "visibility": "specified", "visible_user_ids": [client_self_id], } + + +async def resolve_component_url_or_path( + comp: Any, +) -> Tuple[Optional[str], Optional[str]]: + """尝试从组件解析可上传的远程 URL 或本地路径。 + + 返回 (url_candidate, local_path)。两者可能都为 None。 + 这个函数尽量不抛异常,调用方可按需处理 None。 + """ + url_candidate = None + local_path = None + + async def _get_str_value(coro_or_val): + """辅助函数:统一处理协程或普通值""" + try: + if hasattr(coro_or_val, "__await__"): + result = await coro_or_val + else: + result = coro_or_val + return result if isinstance(result, str) else None + except Exception: + return None + + try: + # 1. 尝试异步方法 + for method in ["convert_to_file_path", "get_file", "register_to_file_service"]: + if not hasattr(comp, method): + continue + try: + value = await _get_str_value(getattr(comp, method)()) + if value: + if value.startswith("http"): + url_candidate = value + break + else: + local_path = value + except Exception: + continue + + # 2. 尝试 get_file(True) 获取可直接访问的 URL + if not url_candidate and hasattr(comp, "get_file"): + try: + value = await _get_str_value(comp.get_file(True)) + if value and value.startswith("http"): + url_candidate = value + except Exception: + pass + + # 3. 回退到同步属性 + if not url_candidate and not local_path: + for attr in ("file", "url", "path", "src", "source"): + try: + value = getattr(comp, attr, None) + if value and isinstance(value, str): + if value.startswith("http"): + url_candidate = value + break + else: + local_path = value + break + except Exception: + continue + + except Exception: + pass + + return url_candidate, local_path + + +def summarize_component_for_log(comp: Any) -> Dict[str, Any]: + """生成适合日志的组件属性字典(尽量不抛异常)。""" + attrs = {} + for a in ("file", "url", "path", "src", "source", "name"): + try: + v = getattr(comp, a, None) + if v is not None: + attrs[a] = v + except Exception: + continue + return attrs + + +async def upload_local_with_retries( + api: Any, + local_path: str, + preferred_name: Optional[str], + folder_id: Optional[str], +) -> Optional[str]: + """尝试本地上传,返回 file id 或 None。如果文件类型不允许则直接失败。""" + try: + res = await api.upload_file(local_path, preferred_name, folder_id) + if isinstance(res, dict): + fid = res.get("id") or (res.get("raw") or {}).get("createdFile", {}).get( + "id" + ) + if fid: + return str(fid) + except Exception: + # 上传失败,直接返回 None,让上层处理错误 + return None + + return None