Skip to content

Commit cfd9445

Browse files
committed
重构: 简化ChatMemoryStore、ChatService, SysMessageService之间的职责。
1 parent 2efea17 commit cfd9445

File tree

6 files changed

+144
-186
lines changed

6 files changed

+144
-186
lines changed

src/main/java/com/xiaozhi/communication/common/ChatSession.java

Lines changed: 1 addition & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,6 @@
77
import com.xiaozhi.entity.SysRole;
88
import com.xiaozhi.enums.ListenMode;
99
import lombok.Data;
10-
import org.springframework.ai.chat.memory.ChatMemory;
11-
import org.springframework.ai.chat.messages.Message;
1210
import org.springframework.ai.tool.ToolCallback;
1311
import reactor.core.publisher.Sinks;
1412

@@ -69,10 +67,7 @@ public abstract class ChatSession {
6967
* 会话的最后有效活动时间
7068
*/
7169
protected Instant lastActivityTime;
72-
/**
73-
* spring ai 聊天记忆
74-
*/
75-
protected ChatMemory chatMemory;
70+
7671
/**
7772
* 会话属性存储
7873
*/
@@ -129,17 +124,8 @@ public List<ToolCallback> getToolCallbacks() {
129124
return toolsSessionHolder.getAllFunction();
130125
}
131126

132-
public void clearMemory() {
133-
chatMemory.clear(sessionId);
134-
}
135127

136-
public List<Message> getHistoryMessages() {
137-
return chatMemory.get(sessionId);
138-
}
139128

140-
public void addHistoryMessage(Message message){
141-
chatMemory.add(sessionId, message);
142-
}
143129
/**
144130
* 会话连接是否打开中
145131
* @return

src/main/java/com/xiaozhi/communication/common/MessageHandler.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -177,6 +177,8 @@ public void afterConnectionClosed(String sessionId) {
177177
audioService.cleanupSession(sessionId);
178178
// 清理对话
179179
dialogueService.cleanupSession(sessionId);
180+
// 清理ChatService缓存的对话历史。
181+
chatService.clearMessageCache(device.getDeviceId());
180182
}
181183

182184
/**

src/main/java/com/xiaozhi/dialogue/llm/ChatService.java

Lines changed: 21 additions & 117 deletions
Original file line numberDiff line numberDiff line change
@@ -67,8 +67,7 @@ public class ChatService {
6767
// 新句子判断的字符阈值
6868
private static final int NEW_SENTENCE_TOKEN_THRESHOLD = 8;
6969

70-
// 历史记录默认限制数量
71-
private static final int DEFAULT_HISTORY_LIMIT = 10;
70+
7271

7372
@Resource(name = "messageWindowChatMemory")
7473
private ChatMemoryStore chatMemoryStore;
@@ -92,52 +91,27 @@ public String chat(ChatSession session, String message, boolean useFunctionCall)
9291
// 获取ChatModel
9392
ChatModel chatModel = chatModelFactory.takeChatModel(session);
9493

95-
if (session.getChatMemory() == null) {// 如果记忆没初始化,则初始化一下
96-
initializeHistory(session);
97-
}
98-
99-
// 获取格式化的历史记录(包含当前用户消息)
100-
List<Message> historyMessages = session.getHistoryMessages();
101-
10294
ChatOptions chatOptions = ToolCallingChatOptions.builder()
10395
.toolCallbacks(useFunctionCall ? session.getToolCallbacks() : new ArrayList<>())
10496
.toolContext(TOOL_CONTEXT_SESSION_KEY, session)
10597
.build();
106-
UserMessage userMessage = new UserMessage(message);
107-
Prompt prompt = Prompt.builder().messages(historyMessages).messages(userMessage).chatOptions(chatOptions)
108-
.build();
98+
99+
UserMessage userMessage = new UserMessage( message);
100+
Prompt prompt = new Prompt(chatMemoryStore.prompt(device, userMessage),chatOptions);
109101

110102
ChatResponse chatResponse = chatModel.call(prompt);
111103
if (chatResponse == null || chatResponse.getResult().getOutput().getText() == null) {
112104
logger.warn("模型响应为空或无生成内容");
113105
return "抱歉,我在处理您的请求时遇到了问题。请稍后再试。";
114106
}
115-
String response = chatResponse.getResult().getOutput().getText();
116-
boolean hasToolCalls = chatResponse.hasToolCalls();
117-
String messageType = SysMessage.MESSAGE_TYPE_NORMAL;// 默认消息类型为普通消息
118-
if (!hasToolCalls) {// 非function消息才加入对话历史,避免调用混乱
119-
// 更新历史消息缓存
120-
session.addHistoryMessage(userMessage);
121-
if (response != null && !response.isEmpty()) {
122-
session.addHistoryMessage(new AssistantMessage(response));
123-
}
124-
} else {
125-
// TODO 后续还需要根据元数据判断是function_call还是mcp调用
126-
// 检查元数据中是否包含工具调用标识
127-
// 发生了工具调用,获取函数调用的名称,通过名称反查类型
128-
// String functionName = chatResponse.getMetadata().get("function_name");
129-
messageType = SysMessage.MESSAGE_TYPE_FUNCTION_CALL;// function消息类型
130-
}
131-
final String finalMessageType = messageType;
107+
AssistantMessage assistantMessage =chatResponse.getResult().getOutput();
108+
132109
Thread.startVirtualThread(() -> {// 异步持久化
133-
// 保存用户消息,会被持久化至数据库。
134-
chatMemoryStore.addUserMessage(device, message, finalMessageType);
135-
if (response != null && !response.isEmpty()) {
136-
// 保存AI消息,会被持久化至数据库。
137-
chatMemoryStore.addAssistantMessage(device, response, finalMessageType);
138-
}
110+
111+
// 保存AI消息,会被持久化至数据库。
112+
chatMemoryStore.addMessage(device, assistantMessage,null);
139113
});
140-
return response;
114+
return assistantMessage.getText();
141115

142116
} catch (Exception e) {
143117
logger.error("处理查询时出错: {}", e.getMessage(), e);
@@ -162,21 +136,10 @@ public Flux<ChatResponse> chatStream(ChatSession session, SysDevice device, Stri
162136
.toolContext(TOOL_CONTEXT_SESSION_KEY, session)
163137
.build();
164138

165-
if (session.getChatMemory() == null) {// 如果记忆没初始化,则初始化一下
166-
initializeHistory(session);
167-
}
168-
// 获取格式化的历史记录(包含当前用户消息)
169-
List<Message> historyMessages = session.getHistoryMessages();
170-
171139
UserMessage userMessage = new UserMessage(message);
172-
historyMessages.add(userMessage);
173-
Prompt prompt = Prompt.builder().messages(historyMessages).chatOptions(chatOptions).build();
140+
Prompt prompt = new Prompt(chatMemoryStore.prompt(device, userMessage),chatOptions);
174141

175142
// 调用实际的流式聊天方法
176-
// return chatModel.stream(prompt).map(response -> (response.getResult() == null
177-
// || response.getResult().getOutput() == null
178-
// || response.getResult().getOutput().getText() == null) ? ""
179-
// : response.getResult().getOutput().getText());
180143
return chatModel.stream(prompt);
181144
}
182145

@@ -230,19 +193,9 @@ public void initializeHistory(ChatSession chatSession) {
230193
return;
231194
}
232195
SysDevice device = chatSession.getSysDevice();
233-
// 如果缓存中不存在该设备的历史记录,则初始化缓存。默认情况下,只缓存当前会话的聊天记录。
234-
// 同一个设备重新连接至服务器,会被标识为不同的sessionId。
235-
// 可以将这理解为spring-ai的conversation会话,将sessionId作为conversationId
236-
// 从数据库加载历史记录
237-
List<SysMessage> history = chatMemoryStore.getMessages(device.getDeviceId(),
238-
SysMessage.MESSAGE_TYPE_NORMAL, DEFAULT_HISTORY_LIMIT);
239-
String systemMessage = chatMemoryStore.getSystemMessage(device.getDeviceId(), device.getRoleId());
240-
MessageWindowChatMemory chatMemory = MessageWindowChatMemory.builder()
241-
.maxMessages(DEFAULT_HISTORY_LIMIT)
242-
.build();// 创建一个新的MessageWindowChatMemory实例,限制为10条消息滚动
243-
chatMemory.add(device.getSessionId(), new SystemMessage(systemMessage));
244-
chatMemory.add(device.getSessionId(), convert(history));
245-
chatSession.setChatMemory(chatMemory);
196+
197+
// 从数据库加载历史记录,初始化缓存。
198+
List<Message> history = chatMemoryStore.initHistory(device.getDeviceId());
246199
logger.info("已初始化设备 {} 的历史记录缓存,共 {} 条消息", device.getDeviceId(), history.size());
247200
}
248201

@@ -255,46 +208,6 @@ public void clearMessageCache(String deviceId) {
255208
chatMemoryStore.clearMessages(deviceId);
256209
}
257210

258-
259-
/**
260-
* 通用添加消息
261-
*
262-
* @param message 消息内容
263-
* @param role 角色名称
264-
* @param messageType 消息类型
265-
*/
266-
public void addMessage(SysDevice device, String message, String role, String messageType, String audioPath) {
267-
chatMemoryStore.addMessage(device.getDeviceId(), device.getSessionId(), role, message, device.getRoleId(),
268-
messageType, audioPath);
269-
}
270-
271-
/**
272-
* 将数据库记录的SysMessag转换为spring-ai的Message。
273-
* 加载的历史都是普通消息(SysMessage.MESSAGE_TYPE_NORMAL)
274-
*
275-
* @param messages
276-
* @return
277-
*/
278-
private List<Message> convert(List<SysMessage> messages) {
279-
if (messages == null || messages.isEmpty()) {
280-
return Collections.emptyList();
281-
}
282-
return messages.stream()
283-
.filter(message -> MessageType.ASSISTANT.getValue().equals(message.getSender())
284-
|| MessageType.USER.getValue().equals(message.getSender()))
285-
.map(message -> {
286-
String role = message.getSender();
287-
// 一般消息("messageType", "NORMAL");//默认为普通消息
288-
Map<String, Object> metadata = Map.of("messageId", message.getMessageId(), "messageType",
289-
message.getMessageType());
290-
return switch (role) {
291-
case "assistant" -> new AssistantMessage(message.getMessage(), metadata);
292-
case "user" -> UserMessage.builder().text(message.getMessage()).metadata(metadata).build();
293-
default -> throw new IllegalArgumentException("Invalid role: " + role);
294-
};
295-
}).collect(Collectors.toList());
296-
}
297-
298211
/**
299212
* 判断文本是否包含实质性内容(不仅仅是空白字符或标点符号)
300213
*
@@ -458,30 +371,21 @@ void persistMessages(String toolName) {
458371
// TODO
459372
// 需要进一步看看ChatModel在流式响应里是如何判断hasTools的,或者直接基于Flux<ChatResponse>已封装好的对象hasToolCalls判断
460373
boolean hasToolCalls = toolName != null && !toolName.isEmpty();
461-
String messageType = hasToolCalls ? SysMessage.MESSAGE_TYPE_FUNCTION_CALL : SysMessage.MESSAGE_TYPE_NORMAL;// TODO
462-
// 后续可以根据名称区分function还是mcp,来细分类型
374+
String messageType = hasToolCalls ? SysMessage.MESSAGE_TYPE_FUNCTION_CALL : SysMessage.MESSAGE_TYPE_NORMAL;
375+
// TODO
376+
// 后续可以根据名称区分function还是mcp,来细分类型
463377

464378
UserMessage userMessage = new UserMessage(message);
465379

466-
// 获取当前对话ID
467-
String dialogueId = session.getDialogueId();
468-
469-
if (!hasToolCalls) {// 非function消息才加入对话历史,避免调用混乱
470-
session.addHistoryMessage(userMessage);
471-
if (!fullResponse.isEmpty()) {
472-
AssistantMessage assistantMessage = new AssistantMessage(fullResponse.toString());
473-
session.addHistoryMessage(assistantMessage);
474-
}
475-
}
476380
Thread.startVirtualThread(() -> {// 异步持久化
477381
String userAudioPath = session.getUserAudioPath();
478-
addMessage(session.getSysDevice(), userMessage.getText(), userMessage.getMessageType().getValue(),
479-
messageType, userAudioPath);
382+
chatMemoryStore.addMessage(session.getSysDevice(), userMessage, userAudioPath);
383+
480384
if (!fullResponse.isEmpty()) {
481385
AssistantMessage assistantMessage = new AssistantMessage(fullResponse.toString());
482386
String assistAudioPath = session.getAssistantAudioPath();
483-
addMessage(session.getSysDevice(), assistantMessage.getText(),
484-
assistantMessage.getMessageType().getValue(), messageType, assistAudioPath);
387+
388+
chatMemoryStore.addMessage(session.getSysDevice(), assistantMessage, assistAudioPath);
485389
}
486390
});
487391
}

src/main/java/com/xiaozhi/dialogue/llm/memory/ChatMemoryStore.java

Lines changed: 12 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -2,41 +2,33 @@
22

33
import com.xiaozhi.entity.SysDevice;
44
import com.xiaozhi.entity.SysMessage;
5+
import org.springframework.ai.chat.messages.AssistantMessage;
6+
import org.springframework.ai.chat.messages.Message;
7+
import org.springframework.ai.chat.messages.UserMessage;
8+
59
import java.util.List;
610

711
/**
812
* 聊天记忆接口
913
* 负责管理聊天历史记录
1014
*/
1115
public interface ChatMemoryStore {
12-
13-
/**
14-
* 添加消息到历史记录
15-
*
16-
* @param deviceId 设备ID
17-
* @param sessionId 会话ID
18-
* @param sender 发送者
19-
* @param content 内容
20-
* @param roleId 角色ID
21-
* @param messageType 消息类型
22-
*/
23-
void addMessage(String deviceId, String sessionId, String sender, String content, Integer roleId, String messageType, String audioPath);
16+
// 历史记录默认限制数量
17+
int DEFAULT_HISTORY_LIMIT = 10;
2418

25-
// TODO 最终要去掉
26-
void addUserMessage(SysDevice device, String message, String messageType);
19+
void addMessage(SysDevice device, UserMessage message, String audioPath);
2720

28-
void addAssistantMessage(SysDevice device, String message, String messageType);
21+
void addMessage(SysDevice device, AssistantMessage message, String audioPath);
2922

23+
List<Message> prompt(SysDevice device,UserMessage userMessage);
3024

3125
/**
32-
* 获取历史消息
26+
* 初始化历史消息
3327
*
34-
* @param deviceId 设备ID
35-
* @param messageType 指定查询的消息类型 - 传null查所有消息
36-
* @param limit 消息数量限制
28+
* @param deviceId 设备标识
3729
* @return 历史消息列表
3830
*/
39-
List<SysMessage> getMessages(String deviceId, String messageType, Integer limit);
31+
List<Message> initHistory(String deviceId);
4032

4133
/**
4234
* 清除设备的历史记录

src/main/java/com/xiaozhi/dialogue/llm/memory/DatabaseChatMemory.java

Lines changed: 21 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -33,25 +33,28 @@ public class DatabaseChatMemory {
3333
private TtsServiceFactory ttsService;
3434

3535
public void addMessage(String deviceId, String sessionId, String sender, String content, Integer roleId, String messageType, String audioPath) {
36-
try {
37-
SysMessage message = new SysMessage();
38-
message.setDeviceId(deviceId);
39-
message.setSessionId(sessionId);
40-
message.setSender(sender);
41-
message.setMessage(content);
42-
message.setRoleId(roleId);
43-
message.setMessageType(messageType);
44-
if (sender == "assistant") {
45-
// 目前生成的语音保存采用默认的语音合成服务,后续可以考虑支持自定义语音合成服务
46-
// todo
47-
message.setAudioPath(ttsService.getDefaultTtsService().textToSpeech(content));
48-
} else {
49-
message.setAudioPath(audioPath);
36+
// TODO 异步虚拟线程处理持久化。
37+
Thread.startVirtualThread(() -> {
38+
try {
39+
SysMessage message = new SysMessage();
40+
message.setDeviceId(deviceId);
41+
message.setSessionId(sessionId);
42+
message.setSender(sender);
43+
message.setMessage(content);
44+
message.setRoleId(roleId);
45+
message.setMessageType(messageType);
46+
if (sender == "assistant") {
47+
// 目前生成的语音保存采用默认的语音合成服务,后续可以考虑支持自定义语音合成服务
48+
// todo
49+
message.setAudioPath(ttsService.getDefaultTtsService().textToSpeech(content));
50+
} else {
51+
message.setAudioPath(audioPath);
52+
}
53+
messageService.add(message);
54+
} catch (Exception e) {
55+
logger.error("保存消息时出错: {}", e.getMessage(), e);
5056
}
51-
messageService.add(message);
52-
} catch (Exception e) {
53-
logger.error("保存消息时出错: {}", e.getMessage(), e);
54-
}
57+
});
5558
}
5659

5760

0 commit comments

Comments
 (0)