Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 43 additions & 37 deletions astrbot/core/pipeline/process_stage/stage.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
from typing import List, Union, AsyncGenerator
from typing import Union, AsyncGenerator
from ..stage import Stage, register_stage
from ..context import PipelineContext
from .method.llm_request import LLMRequestSubStage
from .method.star_request import StarRequestSubStage
from astrbot.core.platform.astr_message_event import AstrMessageEvent
from astrbot.core.star.star_handler import StarHandlerMetadata
from astrbot.core.provider.entities import ProviderRequest
from astrbot.core.utils.session_lock import session_lock_manager
from astrbot.core import logger


Expand All @@ -25,44 +25,50 @@ async def process(
self, event: AstrMessageEvent
) -> Union[None, AsyncGenerator[None, None]]:
"""处理事件"""
activated_handlers: List[StarHandlerMetadata] = event.get_extra(
"activated_handlers"
)
# 有插件 Handler 被激活
if activated_handlers:
async for resp in self.star_request_sub_stage.process(event):
# 生成器返回值处理
if isinstance(resp, ProviderRequest):
# Handler 的 LLM 请求
event.set_extra("provider_request", resp)
_t = False
async for _ in self.llm_request_sub_stage.process(event):
_t = True
yield
if not _t:
umo = event.unified_msg_origin

logger.debug(f"Ready to process event, acquiring session lock. umo = {umo}")
async with session_lock_manager.acquire_lock(session_id=umo):
activated_handlers = event.get_extra("activated_handlers")
if not isinstance(activated_handlers, list):
logger.error(
"activated_handlers is not a list, will skip processing plugin handlers and llm request."
)
return
# 有插件 handler 被激活
if activated_handlers:
async for resp in self.star_request_sub_stage.process(event):
if isinstance(resp, ProviderRequest):
# handler 的 LLM 请求
event.set_extra("provider_request", resp)
_t = False
async for _ in self.llm_request_sub_stage.process(event):
_t = True
yield
if not _t:
yield
else:
yield
else:
yield

# 调用 LLM 相关请求
if not self.ctx.astrbot_config["provider_settings"].get("enable", True):
return
# 调用 LLM 相关请求
if not self.ctx.astrbot_config["provider_settings"].get("enable", True):
return

if (
not event._has_send_oper
and event.is_at_or_wake_command
and not event.call_llm
):
# 是否有过发送操作 and 是否是被 @ 或者通过唤醒前缀
if (
event.get_result() and not event.get_result().is_stopped()
) or not event.get_result():
# 事件没有终止传播
provider = self.ctx.plugin_manager.context.get_using_provider()
not event._has_send_oper
and event.is_at_or_wake_command
and not event.call_llm
):
# 是否有过发送操作 and 是否是被 @ 或者通过唤醒前缀
if (
event.get_result() and not event.get_result().is_stopped()
) or not event.get_result():
# 事件没有终止传播
provider = self.ctx.plugin_manager.context.get_using_provider()

if not provider:
logger.info("未找到可用的 LLM 提供商,请先前往配置服务提供商。")
return
if not provider:
logger.info("未找到可用的 LLM 提供商,请先前往配置服务提供商。")
return

async for _ in self.llm_request_sub_stage.process(event):
yield
async for _ in self.llm_request_sub_stage.process(event):
yield