Skip to content

Commit 8ed9c2e

Browse files
committed
重构: 进一步抽象出一个Conversation的接口,负责短期单个对话的记忆。配合原有的ChatMemory接口,ChatMemory则是负责全局记忆的存储策略及针对不同类型数据库的适配。
1 parent b4a6202 commit 8ed9c2e

File tree

10 files changed

+351
-277
lines changed

10 files changed

+351
-277
lines changed

src/main/java/com/xiaozhi/communication/common/ChatSession.java

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package com.xiaozhi.communication.common;
22

33
import com.xiaozhi.communication.domain.iot.IotDescriptor;
4+
import com.xiaozhi.dialogue.llm.memory.Conversation;
45
import com.xiaozhi.dialogue.llm.tool.ToolsSessionHolder;
56
import com.xiaozhi.dialogue.llm.tool.mcp.device.DeviceMcpHolder;
67
import com.xiaozhi.entity.SysDevice;
@@ -30,6 +31,11 @@ public abstract class ChatSession {
3031
* 设备可用角色列表
3132
*/
3233
protected List<SysRole> sysRoleList;
34+
/**
35+
* 一个Session在某个时刻,只有一个活跃的Conversation。
36+
* 当切换角色时,Conversation应该释放新建。切换角色一般是不频繁的。
37+
*/
38+
protected Conversation conversation;
3339
/**
3440
* 设备iot信息
3541
*/
@@ -142,4 +148,21 @@ public List<ToolCallback> getToolCallbacks() {
142148
public abstract void sendTextMessage(String message);
143149

144150
public abstract void sendBinaryMessage(byte[] message);
151+
152+
/**
153+
* 设置 Conversation,需要与当前活跃角色一致。
154+
* 当切换角色时,会释放当前 Conversation,并新建一个对应于新角色的Conversation。
155+
* @param conversation
156+
*/
157+
public void setConversation( Conversation conversation) {
158+
this.conversation = conversation;
159+
}
160+
161+
/**
162+
* 获取与当前活跃角色一致的 Conversation。
163+
* @return
164+
*/
165+
public Conversation getConversation() {
166+
return conversation;
167+
}
145168
}

src/main/java/com/xiaozhi/communication/common/MessageHandler.java

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
package com.xiaozhi.communication.common;
22

33
import com.xiaozhi.communication.domain.*;
4-
import com.xiaozhi.dialogue.llm.ChatService;
54
import com.xiaozhi.dialogue.llm.factory.ChatModelFactory;
5+
import com.xiaozhi.dialogue.llm.memory.ChatMemory;
6+
import com.xiaozhi.dialogue.llm.memory.Conversation;
67
import com.xiaozhi.dialogue.llm.tool.ToolsGlobalRegistry;
78
import com.xiaozhi.dialogue.llm.tool.ToolsSessionHolder;
89
import com.xiaozhi.dialogue.service.AudioService;
@@ -65,7 +66,7 @@ public class MessageHandler {
6566
private SttServiceFactory sttFactory;
6667

6768
@Resource
68-
private ChatService chatService;
69+
private ChatMemory chatMemory;
6970

7071
@Resource
7172
private ChatModelFactory chatModelFactory;
@@ -105,6 +106,7 @@ public void afterConnection(ChatSession chatSession, String deviceIdAuth) {
105106
//以上同步处理结束后,再启动虚拟线程进行设备初始化,确保chatSession中已设置的sysDevice信息
106107
Thread.startVirtualThread(() -> {
107108
try {
109+
// 从数据库获取角色描述。device.getRoleId()表示当前设备的当前活跃角色,或者上次退出时的活跃角色。
108110
SysRole role = roleService.selectRoleById(device.getRoleId());
109111

110112
if (role.getSttId() != null) {
@@ -121,7 +123,8 @@ public void afterConnection(ChatSession chatSession, String deviceIdAuth) {
121123
}
122124
if (role.getModelId() != null) {
123125
chatModelFactory.takeChatModel(chatSession);// 提前初始化,加速后续使用
124-
chatService.initializeHistory(chatSession);
126+
Conversation conversation = chatMemory.initConversation(device, role, sessionId);
127+
chatSession.setConversation( conversation);
125128
// 注册全局函数
126129
toolsSessionHolder.registerGlobalFunctionTools(chatSession);
127130
}
@@ -178,7 +181,7 @@ public void afterConnectionClosed(String sessionId) {
178181
// 清理对话
179182
dialogueService.cleanupSession(sessionId);
180183
// 清理ChatService缓存的对话历史。
181-
chatService.clearMessageCache(device.getDeviceId());
184+
chatMemory.clearMessages(device.getDeviceId());
182185
}
183186

184187
/**

src/main/java/com/xiaozhi/dialogue/llm/ChatService.java

Lines changed: 18 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,17 @@
33
import com.xiaozhi.communication.common.ChatSession;
44
import com.xiaozhi.dialogue.llm.api.StreamResponseListener;
55
import com.xiaozhi.dialogue.llm.factory.ChatModelFactory;
6-
import com.xiaozhi.dialogue.llm.memory.ChatMemoryStore;
6+
import com.xiaozhi.dialogue.llm.memory.Conversation;
7+
import com.xiaozhi.dialogue.llm.memory.DatabaseChatMemory;
8+
import com.xiaozhi.dialogue.llm.memory.MessageWindowConversation;
79
import com.xiaozhi.entity.SysDevice;
810
import com.xiaozhi.entity.SysMessage;
11+
import com.xiaozhi.entity.SysRole;
12+
import com.xiaozhi.service.SysRoleService;
913
import com.xiaozhi.utils.EmojiUtils;
1014
import jakarta.annotation.Resource;
1115
import org.slf4j.Logger;
1216
import org.slf4j.LoggerFactory;
13-
import org.springframework.ai.chat.memory.MessageWindowChatMemory;
1417
import org.springframework.ai.chat.messages.*;
1518
import org.springframework.ai.chat.metadata.ChatGenerationMetadata;
1619
import org.springframework.ai.chat.model.ChatModel;
@@ -23,14 +26,13 @@
2326
import reactor.core.publisher.Flux;
2427

2528
import java.util.ArrayList;
26-
import java.util.Collections;
2729
import java.util.List;
28-
import java.util.Map;
2930
import java.util.concurrent.atomic.AtomicBoolean;
3031
import java.util.concurrent.atomic.AtomicInteger;
3132
import java.util.regex.Matcher;
3233
import java.util.regex.Pattern;
33-
import java.util.stream.Collectors;
34+
35+
import static com.xiaozhi.dialogue.llm.memory.MessageWindowConversation.DEFAULT_HISTORY_LIMIT;
3436

3537
/**
3638
*
@@ -67,15 +69,15 @@ public class ChatService {
6769
// 新句子判断的字符阈值
6870
private static final int NEW_SENTENCE_TOKEN_THRESHOLD = 8;
6971

70-
71-
72-
@Resource(name = "messageWindowChatMemory")
73-
private ChatMemoryStore chatMemoryStore;
72+
@Resource
73+
private DatabaseChatMemory chatMemoryStore;
7474

7575
// TODO 移到构建者模式,由连接通过认证,可正常对话时,创建实例,构建好一个完整的Role.
7676
@Resource
7777
private ChatModelFactory chatModelFactory;
7878

79+
@Resource
80+
private SysRoleService roleService;
7981
/**
8082
* 处理用户查询(同步方式)
8183
*
@@ -97,7 +99,8 @@ public String chat(ChatSession session, String message, boolean useFunctionCall)
9799
.build();
98100

99101
UserMessage userMessage = new UserMessage( message);
100-
Prompt prompt = new Prompt(chatMemoryStore.prompt(device, userMessage),chatOptions);
102+
List<Message> messages = session.getConversation().prompt( userMessage);
103+
Prompt prompt = new Prompt(messages,chatOptions);
101104

102105
ChatResponse chatResponse = chatModel.call(prompt);
103106
if (chatResponse == null || chatResponse.getResult().getOutput().getText() == null) {
@@ -109,7 +112,7 @@ public String chat(ChatSession session, String message, boolean useFunctionCall)
109112
Thread.startVirtualThread(() -> {// 异步持久化
110113

111114
// 保存AI消息,会被持久化至数据库。
112-
chatMemoryStore.addMessage(device, assistantMessage,null);
115+
session.getConversation().addMessage( assistantMessage,null);
113116
});
114117
return assistantMessage.getText();
115118

@@ -137,7 +140,8 @@ public Flux<ChatResponse> chatStream(ChatSession session, SysDevice device, Stri
137140
.build();
138141

139142
UserMessage userMessage = new UserMessage(message);
140-
Prompt prompt = new Prompt(chatMemoryStore.prompt(device, userMessage),chatOptions);
143+
List<Message> messages = session.getConversation().prompt( userMessage);
144+
Prompt prompt = new Prompt(messages,chatOptions);
141145

142146
// 调用实际的流式聊天方法
143147
return chatModel.stream(prompt);
@@ -184,20 +188,6 @@ public void chatStreamBySentence(ChatSession session, String message, boolean us
184188
}
185189
}
186190

187-
/**
188-
* 初始化设备的历史记录缓存
189-
*
190-
*/
191-
public void initializeHistory(ChatSession chatSession) {
192-
if (chatSession.getSysDevice() == null) {
193-
return;
194-
}
195-
SysDevice device = chatSession.getSysDevice();
196-
197-
// 从数据库加载历史记录,初始化缓存。
198-
List<Message> history = chatMemoryStore.initHistory(device.getDeviceId());
199-
logger.info("已初始化设备 {} 的历史记录缓存,共 {} 条消息", device.getDeviceId(), history.size());
200-
}
201191

202192
/**
203193
* 清除设备缓存
@@ -379,13 +369,13 @@ void persistMessages(String toolName) {
379369

380370
Thread.startVirtualThread(() -> {// 异步持久化
381371
String userAudioPath = session.getUserAudioPath();
382-
chatMemoryStore.addMessage(session.getSysDevice(), userMessage, userAudioPath);
372+
session.getConversation().addMessage( userMessage, userAudioPath);
383373

384374
if (!fullResponse.isEmpty()) {
385375
AssistantMessage assistantMessage = new AssistantMessage(fullResponse.toString());
386376
String assistAudioPath = session.getAssistantAudioPath();
387377

388-
chatMemoryStore.addMessage(session.getSysDevice(), assistantMessage, assistAudioPath);
378+
session.getConversation().addMessage(assistantMessage, assistAudioPath);
389379
}
390380
});
391381
}
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
package com.xiaozhi.dialogue.llm.memory;
2+
3+
import com.xiaozhi.entity.SysDevice;
4+
import com.xiaozhi.entity.SysMessage;
5+
import com.xiaozhi.entity.SysRole;
6+
7+
import java.util.List;
8+
9+
/**
10+
* 聊天记忆接口,全局对象,不针对单个会话,而是负责全局记忆的存储策略及针对不同类型数据库的适配。。
11+
* 不同于SysMessageService,此接口应该是一个更高的抽象层,更多是负责存储策略而并非底层存储的增删改查。
12+
* 已经参考了spring ai 的ChatMemory接口,暂时放弃spring ai 的ChatMemory。
13+
* 以后使用ChatClient与Advisor时直接实现一个更本地友好的ChatMemoryAdvisor。
14+
* Conversation则是参考了 langchain4j 的ChatMemory。
15+
*
16+
*/
17+
public interface ChatMemory {
18+
19+
/**
20+
* 不同的ChatMemory实现类,可以有不同的处理策略,可以初始化不同的Conversation子类。
21+
*
22+
* @param device 设备
23+
* @param role 角色
24+
* @param sessionId 会话ID
25+
* @return 会话
26+
*/
27+
Conversation initConversation(SysDevice device, SysRole role, String sessionId);
28+
// TODO 考虑将参数以Message对象传递,而不是多个参数。在再下一层Service层转换sparing ai的Message为Mapper需要的SysMessage对象
29+
void addMessage(String deviceId, String sessionId, String sender, String content, Integer roleId, String messageType, String audioPath);
30+
31+
List<SysMessage> getMessages(String deviceId, String messageType, Integer limit);
32+
33+
/**
34+
* 清除设备的历史记录
35+
*
36+
* @param deviceId 设备ID
37+
*/
38+
void clearMessages(String deviceId);
39+
40+
}

src/main/java/com/xiaozhi/dialogue/llm/memory/ChatMemoryStore.java

Lines changed: 0 additions & 58 deletions
This file was deleted.
Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
package com.xiaozhi.dialogue.llm.memory;
2+
3+
import com.xiaozhi.entity.SysDevice;
4+
import com.xiaozhi.entity.SysRole;
5+
import org.springframework.ai.chat.messages.AssistantMessage;
6+
import org.springframework.ai.chat.messages.Message;
7+
import org.springframework.ai.chat.messages.UserMessage;
8+
9+
import java.util.List;
10+
11+
/**
12+
* Conversation 是一个 对应于 sys_message 表的,但高于 sys_message 的一个抽象实体。
13+
* deviceID, roleID, sessionID, 实质构成了一次Conversation的全局唯一ID。这个ID必须final 的。
14+
* 在关系型数据库里,可以将deviceID, roleID, sessionID 建一个组合索引,注意顺序sessionID放在最后。
15+
* 在图数据库里, conversation label的节点,连接 device节点、role节点。
16+
* deviceID与roleID本质上不是Conversation的真正属性,而是外键,代表连接的2个对象。
17+
* 只有sessionID是真正挂在Conversation的属性。
18+
*
19+
*/
20+
public abstract class Conversation {
21+
private final SysDevice device;
22+
private final SysRole role;
23+
private final String sessionId;
24+
25+
private List<Message> messages;
26+
27+
public Conversation(SysDevice device, SysRole role, String sessionId, List<Message> messages) {
28+
this.device = device;
29+
this.role = role;
30+
this.sessionId = sessionId;
31+
this.messages = messages;
32+
}
33+
34+
public SysDevice device() {
35+
return device;
36+
}
37+
public SysRole role() {
38+
return role;
39+
}
40+
41+
public String sessionId() {
42+
return sessionId;
43+
}
44+
45+
public List<Message> messages() {
46+
return messages;
47+
}
48+
49+
abstract public void clear();
50+
51+
abstract public void addMessage(UserMessage message, String audioPath);
52+
53+
abstract public void addMessage(AssistantMessage message, String audioPath);
54+
55+
abstract public List<Message> prompt(UserMessage userMessage);
56+
}

0 commit comments

Comments
 (0)