Feat: Add question prediction node #1427

Open · wants to merge 3 commits into main
158 changes: 156 additions & 2 deletions src/backend/bisheng/workflow/nodes/input/input.py
@@ -1,17 +1,24 @@
import json
import os
import shutil
import tempfile
from typing import Any
from typing import Any, Dict, List

from loguru import logger


from langchain_core.messages import HumanMessage, SystemMessage
from langchain_core.runnables import RunnableConfig


from bisheng.api.services.knowledge import KnowledgeService
from bisheng.api.services.knowledge_imp import decide_vectorstores, read_chunk_text
from bisheng.api.services.llm import LLMService
from bisheng.api.utils import md5_hash
from bisheng.api.v1.schemas import FileProcessBase
from bisheng.cache.utils import file_download
from bisheng.chat.types import IgnoreException
from bisheng.workflow.callback.event import GuideQuestionData
from bisheng.workflow.nodes.base import BaseNode


@@ -35,11 +42,24 @@ def __init__(self, *args, **kwargs):
if self.is_dialog_input():
new_node_params['user_input'] = self.node_params['user_input']
new_node_params['dialog_files_content'] = self.node_params.get('dialog_files_content', [])
predict_config = self.node_params.get("input_prediction_config", {})
if predict_config.get("open", False):
self._predict_input_open = True
self._predict_count = predict_config.get("predict_count", 3)
self._history_limit = predict_config.get("history_limit", 10)
predict_model = predict_config.get("model_id")
self._predict_llm = LLMService.get_bisheng_llm(
model_id=predict_model,
temperature=predict_config.get("temperature", 0.7),
cache=False,
)
else:
self._predict_input_open = False
else:
for value_info in self.node_params['form_input']:
new_node_params[value_info['key']] = value_info['value']
self._node_params_map[value_info['key']] = value_info

self.node_params = new_node_params
self._image_ext = ['png', 'jpg', 'jpeg', 'bmp']

@@ -60,6 +80,8 @@ def is_dialog_input(self):

def get_input_schema(self) -> Any:
if self.is_dialog_input():
if self._predict_input_open:
self._predict_input(self.exec_unique_id)
user_input_info = self.node_data.get_variable_info('user_input')
user_input_info.value = [
self.node_data.get_variable_info('dialog_files_content'),
@@ -255,3 +277,135 @@ def parse_upload_file(self, key: str, key_info: dict, value: str) -> dict | None
key_info['file_path']: original_file_path,
key_info['image_file']: image_files_path
}

def _build_system_prompt(self) -> str:
"""构建系统提示词"""

return f"""你是一个智能助手,专门负责分析用户的对话历史,预测用户可能会问的下一个问题。

请基于以下对话历史,分析用户的意图、兴趣点和对话发展趋势,预测用户最可能问的{self._predict_count}个问题。

分析要点:
1. 用户的关注焦点和兴趣方向
2. 对话的逻辑发展趋势
3. 用户可能的深入需求
4. 相关的延伸话题

请确保预测的问题:
- 与对话上下文相关
- 符合用户的交流风格
- 具有实际价值和可操作性
- 按照可能性从高到低排序"""

def _build_output_format_prompt(self) -> str:
"""构建输出格式说明"""
return """
请以JSON格式返回结果:
{
"predicted_questions": [
{
"question": "预测的问题内容",
"probability": "可能性评分(0-1)",
"reason": "预测理由"
}
],
"analysis": "对话趋势分析总结"
}

确保返回有效的JSON格式。"""

def _get_chat_history(self) -> List[Dict]:
"""获取聊天历史"""
try:
# Get history messages from the graph state
history_list = self.graph_state.get_history_list(self._history_limit)

# Convert to an easier-to-process format
formatted_history = []
for msg in history_list:
if hasattr(msg, "content"):
content = msg.content
if isinstance(content, list) and len(content) > 0:
# Multimodal message: keep only the text parts
text_content = ""
for item in content:
if isinstance(item, dict) and item.get("type") == "text":
text_content += item.get("text", "")
content = text_content

role = "user" if hasattr(msg, "__class__") and "Human" in msg.__class__.__name__ else "assistant"
formatted_history.append({"role": role, "content": str(content)})

return formatted_history
except Exception as e:
logger.warning(f"获取聊天历史失败: {e}")
return []
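
To make the history flattening above concrete, here is a minimal standalone sketch (assuming `langchain_core` message objects; `flatten_history` is a hypothetical helper, not part of this PR) of how multimodal content collapses to plain text and how the role is inferred from the message class name:

```python
# Standalone sketch of the history-flattening logic above; illustrative only.
from langchain_core.messages import AIMessage, HumanMessage

def flatten_history(messages, limit=10):
    """Collapse multimodal message content to text and tag each message with a role."""
    formatted = []
    for msg in messages[-limit:]:
        content = msg.content
        if isinstance(content, list):
            # Multimodal payload: keep only the text chunks
            content = "".join(
                item.get("text", "")
                for item in content
                if isinstance(item, dict) and item.get("type") == "text"
            )
        role = "user" if "Human" in msg.__class__.__name__ else "assistant"
        formatted.append({"role": role, "content": str(content)})
    return formatted

history = [
    HumanMessage(content=[{"type": "text", "text": "帮我总结这份报告"},
                          {"type": "image_url", "image_url": {"url": "https://example.com/a.png"}}]),
    AIMessage(content="报告主要讨论了三个问题。"),
]
print(flatten_history(history))
# [{'role': 'user', 'content': '帮我总结这份报告'}, {'role': 'assistant', 'content': '报告主要讨论了三个问题。'}]
```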

def _format_history_for_prompt(self, history: List[Dict]) -> str:
"""将历史消息格式化为提示词"""
formatted = "对话历史:\n"
for i, msg in enumerate(history, 1):
role_name = "用户" if msg["role"] == "user" else "助手"
formatted += f"{i}. {role_name}: {msg['content']}\n"

return formatted

def _parse_llm_output(self, output: str) -> Dict:
"""解析LLM输出"""
try:
# Try to parse the output as JSON
# Strip possible markdown code-fence markers first
cleaned_output = output.strip()
if cleaned_output.startswith("```json"):
cleaned_output = cleaned_output[7:]
if cleaned_output.endswith("```"):
cleaned_output = cleaned_output[:-3]
cleaned_output = cleaned_output.strip()

result = json.loads(cleaned_output)
return {
"questions": result.get("predicted_questions", []),
"analysis": result.get("analysis", ""),
}

except Exception as e:
logger.error(f"解析LLM输出失败: {e}")
return {
"questions": [{"question": "解析预测结果失败", "probability": 0.0, "reason": str(e)}],
"analysis": "输出解析出错",
"format": "error",
"raw_output": output,
}
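
As a sanity check on the parsing helper above, the following is a hedged standalone sketch (plain `json`, not the node class) of how a fenced model reply gets cleaned and read back; the sample reply text is made up:

```python
# Minimal sketch of the fence-stripping / JSON-parsing behaviour implemented above.
import json

raw_output = ('```json\n'
              '{"predicted_questions": [{"question": "报告的主要结论是什么?", '
              '"probability": "0.9", "reason": "用户刚上传了报告"}], '
              '"analysis": "用户正在围绕报告内容提问"}\n'
              '```')

cleaned = raw_output.strip()
if cleaned.startswith('```json'):
    cleaned = cleaned[7:]        # drop the opening fence marker
if cleaned.endswith('```'):
    cleaned = cleaned[:-3]       # drop the closing fence
parsed = json.loads(cleaned.strip())

print([q["question"] for q in parsed["predicted_questions"]])
# ['报告的主要结论是什么?']
```

If the reply is not valid JSON, the node falls back to a single placeholder question and keeps the raw output in the returned dict, so downstream consumers always receive a list.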

def _predict_input(self, unique_id: str) -> None:
chat_history = self._get_chat_history()
if not chat_history:
return
self._history_used = chat_history

# Build the prompts
system_prompt = self._build_system_prompt()
self._system_prompt_used = system_prompt

history_text = self._format_history_for_prompt(chat_history)
output_format = self._build_output_format_prompt()

user_prompt = f"{history_text}\n\n{output_format}"

# Build the messages
messages = [SystemMessage(content=system_prompt), HumanMessage(content=user_prompt)]

# Invoke the LLM
result = self._predict_llm.invoke(messages, config=RunnableConfig())

# Parse the result
parsed_result = self._parse_llm_output(result.content)

self.callback_manager.on_guide_question(
GuideQuestionData(
node_id=self.id,
guide_question=[q["question"] for q in parsed_result["questions"]],
unique_id=unique_id,
)
)
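
For reference, a rough sketch of the dialog-input node params the new `__init__` branch reads; the surrounding node payload is simplified and the `model_id` value is made up:

```python
# Hypothetical node_params for a dialog input node with prediction enabled.
input_node_params = {
    "user_input": "",
    "dialog_files_content": [],
    "input_prediction_config": {
        "open": True,         # turns on _predict_input in get_input_schema
        "model_id": 7,        # id of a bisheng LLM; value here is made up
        "temperature": 0.7,
        "predict_count": 3,   # how many questions to predict
        "history_limit": 10,  # history messages read by the backend
                              # (the frontend template defaults use the key "history_count")
    },
}
```

When `open` is true, `get_input_schema` calls `_predict_input`, which ends by emitting a `GuideQuestionData` event whose `guide_question` field carries the predicted question strings.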
2 changes: 1 addition & 1 deletion src/backend/bisheng/workflow/nodes/node_manage.py
@@ -24,7 +24,7 @@
NodeType.CONDITION.value: ConditionNode,
NodeType.AGENT.value: AgentNode,
NodeType.CODE.value: CodeNode,
NodeType.LLM.value: LLMNode
NodeType.LLM.value: LLMNode,
}


10 changes: 9 additions & 1 deletion src/frontend/platform/public/locales/en/flow.json
@@ -165,5 +165,13 @@
"storeFilesSentInCurrentSession": "Store files sent in the current session",
"displayNameTooltip": "Displayed on the user conversation page",
"variableNameTooltipFile": "Used to store content entered on the user conversation page; can be selected from the temporary session file list",
"variableNameTooltipOther": "Used to store content entered on the user conversation page; can be referenced in other nodes"
"variableNameTooltipOther": "Used to store content entered on the user conversation page; can be referenced in other nodes",
"modelSettings": "Model Settings",
"model": "Model",
"modelRequired": "Model cannot be empty",
"selectModel": "Select Model",
"predictCount": "Prediction Count",
"historyCount": "History Count",
"temperature": "Temperature",
"predictionSettings": "Prediction Settings"
}
10 changes: 9 additions & 1 deletion src/frontend/platform/public/locales/zh/flow.json
@@ -165,5 +165,13 @@
"storeFilesSentInCurrentSession": "存储当前会话中发送的文件",
"displayNameTooltip": "用户会话页面展示此名称",
"variableNameTooltipFile": "用于存储用户会话页面填写的内容,可在临时会话文件列表中选择此变量",
"variableNameTooltipOther": "用于存储用户会话页面填写的内容,可在其他节点中引用此变量"
"variableNameTooltipOther": "用于存储用户会话页面填写的内容,可在其他节点中引用此变量",
"modelSettings": "模型设置",
"model": "模型",
"modelRequired": "模型不可为空",
"selectModel": "选择模型",
"predictCount": "预测问题数量",
"historyCount": "历史消息数量",
"temperature": "温度",
"predictionSettings": "预测配置"
}
28 changes: 28 additions & 0 deletions src/frontend/platform/src/controllers/API/workflow.ts
@@ -220,6 +220,20 @@ const workflowTemplate = [
"tab": "dialog_input",
"help": "提取上传文件中的图片文件,当助手或大模型节点使用多模态大模型时,可传入此图片。"
},
{
"key": "input_prediction_config",
"label": "输入预测配置",
"type": "input_prediction_config",
"tab": "dialog_input",
"value": {
"open": false,
"model_id": "",
"temperature": 0.7,
"predict_count": 3,
"history_count": 10
},
"help": "根据历史消息预测用户下一个可能的输入内容。"
},
{
"key": "form_input",
"global": "item:form_input",
@@ -955,6 +969,20 @@ const workflowTemplateEN = [
"tab": "dialog_input",
"help": "Extract the image file from the uploaded file. When the assistant or large model node uses the MultiModal Machine Learning large model, this image can be passed in."
},
{
"key": "input_prediction_config",
"label": "Input Prediction Configuration",
"type": "input_prediction_config",
"tab": "dialog_input",
"value": {
"open": false,
"model_id": "",
"temperature": 0.7,
"predict_count": 3,
"history_count": 10
},
"help": "Predict the user's next possible input based on message history."
},
{
"global": "item:form_input",
"key": "form_input",
@@ -1,5 +1,5 @@
import { cname } from "@/components/bs-ui/utils";
import { BookOpenTextIcon, Bot, Brain, Code2, FileDown, FileSearch, FlagTriangleRight, Hammer, Home, Keyboard, MessagesSquareIcon, Split } from "lucide-react";
import { BookOpenTextIcon, Bot, Brain, Code2, FileDown, FileSearch, FlagTriangleRight, Hammer, Home, Keyboard, MessagesSquareIcon, Split, HelpCircle } from "lucide-react";
export const Icons = {
'start': Home,
'input': Keyboard,
@@ -11,7 +11,7 @@ export const Icons = {
'agent': Bot,
'end': FlagTriangleRight,
'condition': Split,
'report': FileDown
'report': FileDown,
}
export const Colors = {
'start': 'bg-[#FFD89A]',
@@ -24,7 +24,7 @@ export const Colors = {
'agent': 'bg-[#FFD89A]',
'end': 'bg-red-400',
'condition': 'bg-[#EDC9E9]',
'report': 'bg-[#9CE4F4]'
'report': 'bg-[#9CE4F4]',
}

export default function NodeLogo({ type, className = '', colorStr = '' }) {
@@ -8,6 +8,7 @@ import HistoryNumItem from "./component/HistoryNumItem";
import InputFormItem from "./component/InputFormItem";
import InputItem from "./component/InputItem";
import InputListItem from "./component/InputListItem";
import InputPredictionConfigItem from "./component/InputPredictionConfigItem";
import KnowledgeQaSelectItem from "./component/KnowledgeQaSelectItem";
import KnowledgeSelectItem from "./component/KnowledgeSelectItem";
import ModelItem from "./component/ModelItem";
@@ -138,6 +139,8 @@ export default function Parameter({ node, nodeId, item, onOutPutChange, onStatus
return <ReportItem nodeId={nodeId} data={item} onChange={handleOnNewValue} onValidate={bindValidate} />;
case 'sql_config':
return <SqlConfigItem nodeId={nodeId} data={item} onChange={handleOnNewValue} onValidate={bindValidate} />;
case 'input_prediction_config':
return <InputPredictionConfigItem data={item} onChange={handleOnNewValue} onValidate={bindValidate} />;
case 'select_fileaccept':
return <FileTypeSelect data={item} onChange={(val) => {
// group_params[0] is affected by the input template