diff --git a/astrbot/api/event/filter/__init__.py b/astrbot/api/event/filter/__init__.py index dd737e3ff..830de7892 100644 --- a/astrbot/api/event/filter/__init__.py +++ b/astrbot/api/event/filter/__init__.py @@ -12,6 +12,8 @@ register_llm_tool as llm_tool, register_on_decorating_result as on_decorating_result, register_after_message_sent as after_message_sent, + register_on_star_activated as on_star_activated, + register_on_star_deactivated as on_star_deactivated, ) from astrbot.core.star.filter.event_message_type import ( @@ -46,4 +48,6 @@ "on_decorating_result", "after_message_sent", "on_llm_response", + "on_star_activated", + "on_star_deactivated", ] diff --git a/astrbot/core/star/register/__init__.py b/astrbot/core/star/register/__init__.py index 55a4393da..b7db61ed9 100644 --- a/astrbot/core/star/register/__init__.py +++ b/astrbot/core/star/register/__init__.py @@ -14,6 +14,8 @@ register_agent, register_on_decorating_result, register_after_message_sent, + register_on_star_activated, + register_on_star_deactivated, ) __all__ = [ @@ -32,4 +34,6 @@ "register_agent", "register_on_decorating_result", "register_after_message_sent", + "register_on_star_activated", + "register_on_star_deactivated", ] diff --git a/astrbot/core/star/register/star_handler.py b/astrbot/core/star/register/star_handler.py index 101f3a95f..20138a9de 100644 --- a/astrbot/core/star/register/star_handler.py +++ b/astrbot/core/star/register/star_handler.py @@ -450,3 +450,31 @@ def decorator(awaitable): return awaitable return decorator + + +def register_on_star_activated(star_name: str = None, **kwargs): + """当指定插件被激活时""" + + def decorator(awaitable): + handler_md = get_handler_or_create( + awaitable, EventType.OnStarActivatedEvent, **kwargs + ) + if star_name: + handler_md.extras_configs["target_star_name"] = star_name + return awaitable + + return decorator + + +def register_on_star_deactivated(star_name: str = None, **kwargs): + """当指定插件被停用时""" + + def decorator(awaitable): + handler_md = get_handler_or_create( + awaitable, EventType.OnStarDeactivatedEvent, **kwargs + ) + if star_name: + handler_md.extras_configs["target_star_name"] = star_name + return awaitable + + return decorator diff --git a/astrbot/core/star/star.py b/astrbot/core/star/star.py index 0563e8cc8..76698efe6 100644 --- a/astrbot/core/star/star.py +++ b/astrbot/core/star/star.py @@ -32,6 +32,8 @@ class StarMetadata: """插件版本""" repo: str | None = None """插件仓库地址""" + dependencies: list[str] = field(default_factory=list) + """插件依赖列表""" star_cls_type: type[Star] | None = None """插件的类对象的类型""" diff --git a/astrbot/core/star/star_handler.py b/astrbot/core/star/star_handler.py index 43a74396a..3ed83ec2f 100644 --- a/astrbot/core/star/star_handler.py +++ b/astrbot/core/star/star_handler.py @@ -98,6 +98,8 @@ class EventType(enum.Enum): OnCallingFuncToolEvent = enum.auto() # 调用函数工具 OnAfterMessageSentEvent = enum.auto() # 发送消息后 + OnStarActivatedEvent = enum.auto() # 插件启用 + OnStarDeactivatedEvent = enum.auto() # 插件禁用 @dataclass class StarHandlerMetadata: diff --git a/astrbot/core/star/star_manager.py b/astrbot/core/star/star_manager.py index 5fb1b1dfa..2a952cf80 100644 --- a/astrbot/core/star/star_manager.py +++ b/astrbot/core/star/star_manager.py @@ -30,6 +30,8 @@ from .star import star_map, star_registry from .star_handler import star_handlers_registry from .updator import PluginUpdator +from .star_handler import EventType, StarHandlerMetadata +import networkx as nx try: from watchfiles import PythonFilter, awatch @@ -144,13 +146,11 @@ def _get_modules(path): if os.path.exists(os.path.join(path, d, "main.py")) or os.path.exists( os.path.join(path, d, d + ".py") ): - modules.append( - { - "pname": d, - "module": module_str, - "module_path": os.path.join(path, d, module_str), - } - ) + modules.append({ + "pname": d, + "module": module_str, + "module_path": os.path.join(path, d, module_str), + }) return modules def _get_plugin_modules(self) -> list[dict]: @@ -226,6 +226,7 @@ def _load_plugin_metadata(plugin_path: str, plugin_obj=None) -> StarMetadata | N desc=metadata["desc"], version=metadata["version"], repo=metadata["repo"] if "repo" in metadata else None, + dependencies=metadata.get("dependencies", []), ) return metadata @@ -321,25 +322,17 @@ async def reload(self, specified_plugin_name=None): star_handlers_registry.clear() star_map.clear() star_registry.clear() + plugin_modules = await self._get_load_order() + result = await self.load(plugin_modules=plugin_modules) else: # 只重载指定插件 - smd = star_map.get(specified_module_path) - if smd: - try: - await self._terminate_plugin(smd) - except Exception as e: - logger.warning(traceback.format_exc()) - logger.warning( - f"插件 {smd.name} 未被正常终止: {str(e)}, 可能会导致该插件运行不正常。" - ) - if smd.name: - await self._unbind_plugin(smd.name, specified_module_path) - - result = await self.load(specified_module_path) + result = await self.batch_reload( + specified_module_path=specified_module_path + ) return result - async def load(self, specified_module_path=None, specified_dir_name=None): + async def load(self, plugin_modules=None): """载入插件。 当 specified_module_path 或者 specified_dir_name 不为 None 时,只载入指定的插件。 @@ -356,10 +349,11 @@ async def load(self, specified_module_path=None, specified_dir_name=None): inactivated_llm_tools = await sp.global_get("inactivated_llm_tools", []) alter_cmd = await sp.global_get("alter_cmd", {}) - plugin_modules = self._get_plugin_modules() if plugin_modules is None: return False, "未找到任何插件模块" - + logger.info( + f"正在按顺序加载插件: {[plugin_module['pname'] for plugin_module in plugin_modules]}" + ) fail_rec = "" # 导入插件模块,并尝试实例化插件类 @@ -375,12 +369,6 @@ async def load(self, specified_module_path=None, specified_dir_name=None): path = "data.plugins." if not reserved else "packages." path += root_dir_name + "." + module_str - # 检查是否需要载入指定的插件 - if specified_module_path and path != specified_module_path: - continue - if specified_dir_name and root_dir_name != specified_dir_name: - continue - logger.info(f"正在载入插件 {root_dir_name} ...") # 尝试导入模块 @@ -451,6 +439,9 @@ async def load(self, specified_module_path=None, specified_dir_name=None): metadata.star_cls = metadata.star_cls_type( context=self.context ) + await self._trigger_star_lifecycle_event( + EventType.OnStarActivatedEvent, metadata + ) else: logger.info(f"插件 {metadata.name} 已被禁用。") @@ -622,7 +613,8 @@ async def install_plugin(self, repo_url: str, proxy=""): plugin_path = await self.updator.install(repo_url, proxy) # reload the plugin dir_name = os.path.basename(plugin_path) - await self.load(specified_dir_name=dir_name) + plugin_modules = await self._get_load_order(specified_dir_name=dir_name) + await self.batch_reload(plugin_modules=plugin_modules) # Get the plugin metadata to return repo info plugin = self.context.get_registered_star(dir_name) @@ -778,8 +770,7 @@ async def turn_off_plugin(self, plugin_name: str): plugin.activated = False - @staticmethod - async def _terminate_plugin(star_metadata: StarMetadata): + async def _terminate_plugin(self, star_metadata: StarMetadata): """终止插件,调用插件的 terminate() 和 __del__() 方法""" logger.info(f"正在终止插件 {star_metadata.name} ...") @@ -788,14 +779,18 @@ async def _terminate_plugin(star_metadata: StarMetadata): logger.debug(f"插件 {star_metadata.name} 未被激活,不需要终止,跳过。") return + await self._trigger_star_lifecycle_event( + EventType.OnStarDeactivatedEvent, star_metadata + ) + if star_metadata.star_cls is None: return - if '__del__' in star_metadata.star_cls_type.__dict__: + if "__del__" in star_metadata.star_cls_type.__dict__: asyncio.get_event_loop().run_in_executor( None, star_metadata.star_cls.__del__ ) - elif 'terminate' in star_metadata.star_cls_type.__dict__: + elif "terminate" in star_metadata.star_cls_type.__dict__: await star_metadata.star_cls.terminate() async def turn_on_plugin(self, plugin_name: str): @@ -832,7 +827,8 @@ async def install_plugin_from_file(self, zip_file_path: str): except BaseException as e: logger.warning(f"删除插件压缩包失败: {str(e)}") # await self.reload() - await self.load(specified_dir_name=dir_name) + plugin_modules = await self._get_load_order(specified_dir_name=dir_name) + await self.batch_reload(plugin_modules=plugin_modules) # Get the plugin metadata to return repo info plugin = self.context.get_registered_star(dir_name) @@ -865,3 +861,147 @@ async def install_plugin_from_file(self, zip_file_path: str): } return plugin_info + + async def _trigger_star_lifecycle_event( + self, event_type: EventType, star_metadata: StarMetadata + ): + """ + 内部辅助函数,用于触发插件(Star)相关的生命周期事件。 + Args: + event_type: 要触发的事件类型 (EventType.OnStarActivatedEvent 或 EventType.OnStarDeactivatedEvent)。 + star_metadata: 触发事件的插件的 StarMetadata 对象。 + """ + handlers_to_run: list[StarHandlerMetadata] = [] + # 获取所有监听该事件类型的 handlers + handlers = star_handlers_registry.get_handlers_by_event_type(event_type) + + for handler in handlers: + # 检查这个 handler 是否监听了特定的插件名 + target_star_name = handler.extras_configs.get("target_star_name") + if target_star_name and target_star_name == star_metadata.name: + # 如果指定了目标插件名,则只在匹配时添加 + handlers_to_run.append(handler) + + for handler in handlers_to_run: + try: + # 调用插件的钩子函数,并传入 StarMetadata 对象 + logger.info( + f"hook({event_type.name}) -> {star_map[handler.handler_module_path].name} - {handler.handler_name} (目标插件: {star_metadata.name})" + ) + await handler.handler(star_metadata) # 传递参数 + except Exception: + logger.error( + f"执行插件 {handler.handler_name} 的 {event_type.name} 钩子时出错: {traceback.format_exc()}" + ) + + def _get_plugin_dir_path(self, root_dir_name: str, is_reserved: bool) -> str: + """根据插件的根目录名和是否为保留插件,返回插件的完整文件路径。""" + return ( + os.path.join(self.plugin_store_path, root_dir_name) + if not is_reserved + else os.path.join(self.reserved_plugin_path, root_dir_name) + ) + + def _build_module_path(self, plugin_module_info: dict) -> str: + """根据插件模块信息构建完整的模块路径。""" + reserved = plugin_module_info.get("reserved", False) + path_prefix = "packages." if reserved else "data.plugins." + return ( + f"{path_prefix}{plugin_module_info['pname']}.{plugin_module_info['module']}" + ) + + async def _get_load_order( + self, specified_dir_name: str = None, specified_module_path: str = None + ): + star_graph = self._build_star_graph() + if star_graph is None: + return None + try: + if specified_dir_name: + for node in star_graph: + if ( + star_graph.nodes[node]["data"].get("pname") + == specified_dir_name + ): + dependent_nodes = nx.descendants(star_graph, node) + sub_graph = star_graph.subgraph(dependent_nodes.union({node})) + load_order = list(nx.topological_sort(sub_graph)) + return [star_graph.nodes[node]["data"] for node in load_order] + elif specified_module_path: + for node in star_graph: + if specified_module_path == self._build_module_path( + star_graph.nodes[node].get("data") + ): + dependent_nodes = nx.descendants(star_graph, node) + sub_graph = star_graph.subgraph(dependent_nodes.union({node})) + load_order = list(nx.topological_sort(sub_graph)) + return [star_graph.nodes[node]["data"] for node in load_order] + else: + sorted_nodes = list(nx.topological_sort(star_graph)) + + reserved_plugins = [ + star_graph.nodes[node]["data"] + for node in sorted_nodes + if star_graph.nodes[node]["data"].get("reserved", False) + ] + non_reserved_plugins = [ + star_graph.nodes[node]["data"] + for node in sorted_nodes + if not star_graph.nodes[node]["data"].get("reserved", False) + ] + + return reserved_plugins + non_reserved_plugins + + except nx.NetworkXUnfeasible: + logger.error("出现循环依赖,无法确定加载顺序,按自然顺序加载") + return [star_graph.nodes[node]["data"] for node in star_graph] + + def _build_star_graph(self): + plugin_modules = self._get_plugin_modules() + if plugin_modules is None: + return None + G = nx.DiGraph() + for plugin_module in plugin_modules: + root_dir_name = plugin_module["pname"] + is_reserved = plugin_module.get("reserved", False) + plugin_dir_path = self._get_plugin_dir_path(root_dir_name, is_reserved) + G.add_node(root_dir_name, data=plugin_module) + try: + metadata = self._load_plugin_metadata(plugin_dir_path) + if metadata: + for dep_name in metadata.dependencies: + G.add_edge(root_dir_name, dep_name) + except Exception: + pass + # 过滤不存在的依赖(出边没有data, 就删除指向的节点) + nodes_to_remove = [] + for node_name in list(G.nodes()): + for neighbor in list(G.neighbors(node_name)): + if G.nodes[neighbor].get("data") is None: + nodes_to_remove.append(neighbor) + logger.warning( + f"插件 {node_name} 声明依赖 {neighbor}, 但该插件未被发现,跳过加载。" + ) + for node in nodes_to_remove: + G.remove_node(node) + return G + + async def batch_reload(self, specified_module_path=None, plugin_modules=None): + if not plugin_modules: + plugin_modules = await self._get_load_order( + specified_module_path=specified_module_path + ) + for plugin_module in plugin_modules: + specified_module_path = self._build_module_path(plugin_module) + smd = star_map.get(specified_module_path) + if smd: + try: + await self._terminate_plugin(smd) + except Exception as e: + logger.warning(traceback.format_exc()) + logger.warning( + f"插件 {smd.name} 未被正常终止: {str(e)}, 可能会导致该插件运行不正常。" + ) + await self._unbind_plugin(smd.name, specified_module_path) + + return await self.load(plugin_modules=plugin_modules) diff --git a/pyproject.toml b/pyproject.toml index 336b38433..aa53582e9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -49,6 +49,7 @@ dependencies = [ "watchfiles>=1.0.5", "websockets>=15.0.1", "wechatpy>=1.8.18", + "networkx>=3.4.2", ] [project.scripts] diff --git a/requirements.txt b/requirements.txt index bd8f0eca0..d98463869 100644 --- a/requirements.txt +++ b/requirements.txt @@ -39,4 +39,5 @@ faiss-cpu aiosqlite py-cord>=2.6.1 slack-sdk -pydub \ No newline at end of file +pydub +networkx