Skip to content

Commit 2efea17

Browse files
committed
重构: 简化ChatMemory、ChatService, SysMessageService之间的职责。
1 parent d6df8f4 commit 2efea17

File tree

4 files changed

+158
-74
lines changed

4 files changed

+158
-74
lines changed

src/main/java/com/xiaozhi/dialogue/llm/ChatService.java

Lines changed: 3 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ public class ChatService {
7070
// 历史记录默认限制数量
7171
private static final int DEFAULT_HISTORY_LIMIT = 10;
7272

73-
@Resource
73+
@Resource(name = "messageWindowChatMemory")
7474
private ChatMemoryStore chatMemoryStore;
7575

7676
// TODO 移到构建者模式,由连接通过认证,可正常对话时,创建实例,构建好一个完整的Role.
@@ -131,10 +131,10 @@ public String chat(ChatSession session, String message, boolean useFunctionCall)
131131
final String finalMessageType = messageType;
132132
Thread.startVirtualThread(() -> {// 异步持久化
133133
// 保存用户消息,会被持久化至数据库。
134-
this.addUserMessage(device, message, finalMessageType);
134+
chatMemoryStore.addUserMessage(device, message, finalMessageType);
135135
if (response != null && !response.isEmpty()) {
136136
// 保存AI消息,会被持久化至数据库。
137-
this.addAssistantMessage(device, response, finalMessageType);
137+
chatMemoryStore.addAssistantMessage(device, response, finalMessageType);
138138
}
139139
});
140140
return response;
@@ -255,26 +255,6 @@ public void clearMessageCache(String deviceId) {
255255
chatMemoryStore.clearMessages(deviceId);
256256
}
257257

258-
/**
259-
* 添加用户消息
260-
*
261-
* @param message 用户消息
262-
*/
263-
public void addUserMessage(SysDevice device, String message, String messageType) {
264-
// 更新缓存
265-
chatMemoryStore.addMessage(device.getDeviceId(), device.getSessionId(), "user", message,
266-
device.getRoleId(), messageType, null);
267-
}
268-
269-
/**
270-
* 添加AI消息
271-
*
272-
* @param message AI消息
273-
*/
274-
public void addAssistantMessage(SysDevice device, String message, String messageType) {
275-
chatMemoryStore.addMessage(device.getDeviceId(), device.getSessionId(), "assistant", message,
276-
device.getRoleId(), messageType, null);
277-
}
278258

279259
/**
280260
* 通用添加消息

src/main/java/com/xiaozhi/dialogue/llm/memory/ChatMemoryStore.java

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
package com.xiaozhi.dialogue.llm.memory;
22

3+
import com.xiaozhi.entity.SysDevice;
34
import com.xiaozhi.entity.SysMessage;
45
import java.util.List;
56

@@ -20,7 +21,13 @@ public interface ChatMemoryStore {
2021
* @param messageType 消息类型
2122
*/
2223
void addMessage(String deviceId, String sessionId, String sender, String content, Integer roleId, String messageType, String audioPath);
23-
24+
25+
// TODO 最终要去掉
26+
void addUserMessage(SysDevice device, String message, String messageType);
27+
28+
void addAssistantMessage(SysDevice device, String message, String messageType);
29+
30+
2431
/**
2532
* 获取历史消息
2633
*

src/main/java/com/xiaozhi/dialogue/llm/memory/DatabaseChatMemory.java

Lines changed: 3 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -23,22 +23,15 @@
2323
* 基于数据库的聊天记忆实现
2424
*/
2525
@Service
26-
public class DatabaseChatMemory implements ChatMemoryStore {
26+
public class DatabaseChatMemory {
2727
private static final Logger logger = LoggerFactory.getLogger(DatabaseChatMemory.class);
2828

2929
@Autowired
3030
private SysMessageService messageService;
3131

32-
@Autowired
33-
private SysRoleService roleService;
34-
3532
@Autowired
3633
private TtsServiceFactory ttsService;
3734

38-
// 缓存系统消息,避免频繁查询数据库
39-
private Map<String, String> systemMessageCache = new ConcurrentHashMap<>();
40-
41-
@Override
4235
public void addMessage(String deviceId, String sessionId, String sender, String content, Integer roleId, String messageType, String audioPath) {
4336
try {
4437
SysMessage message = new SysMessage();
@@ -61,7 +54,7 @@ public void addMessage(String deviceId, String sessionId, String sender, String
6154
}
6255
}
6356

64-
@Override
57+
6558
public List<SysMessage> getMessages(String deviceId, String messageType, Integer limit) {
6659
try {
6760
SysMessage queryMessage = new SysMessage();
@@ -79,56 +72,16 @@ public List<SysMessage> getMessages(String deviceId, String messageType, Integer
7972
}
8073
}
8174

82-
@Override
75+
8376
public void clearMessages(String deviceId) {
8477
try {
8578
// 清除设备的历史消息
8679
SysMessage deleteMessage = new SysMessage();
8780
deleteMessage.setDeviceId(deviceId);
8881
// messageService.update(deleteMessage);
89-
90-
// 清除缓存
91-
systemMessageCache.keySet().removeIf(key -> key.startsWith(deviceId + ":"));
9282
} catch (Exception e) {
9383
logger.error("清除设备历史记录时出错: {}", e.getMessage(), e);
9484
}
9585
}
9686

97-
@Override
98-
public String getSystemMessage(String deviceId, Integer roleId) {
99-
100-
if (roleId == null) {
101-
return "";
102-
}
103-
String cacheKey = deviceId + ":" + roleId;
104-
105-
// 先从缓存获取
106-
if (systemMessageCache.containsKey(cacheKey)) {
107-
return systemMessageCache.get(cacheKey);
108-
}
109-
110-
try {
111-
// 从数据库获取角色描述
112-
SysRole role = roleService.selectRoleById(roleId);
113-
if (role != null && role.getRoleDesc() != null) {
114-
String systemMessage = role.getRoleDesc();
115-
// 存入缓存
116-
systemMessageCache.put(cacheKey, systemMessage);
117-
return systemMessage;
118-
}
119-
} catch (Exception e) {
120-
logger.error("获取系统消息时出错: {}", e.getMessage(), e);
121-
}
122-
123-
return "";
124-
}
125-
126-
@Override
127-
public void setSystemMessage(String deviceId, Integer roleId, String systemMessage) {
128-
String cacheKey = deviceId + ":" + roleId;
129-
130-
// 更新缓存
131-
systemMessageCache.put(cacheKey, systemMessage);
132-
133-
}
13487
}
Lines changed: 144 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,144 @@
1+
package com.xiaozhi.dialogue.llm.memory;
2+
3+
import com.xiaozhi.entity.SysDevice;
4+
import com.xiaozhi.entity.SysMessage;
5+
import com.xiaozhi.entity.SysRole;
6+
import com.xiaozhi.service.SysRoleService;
7+
import org.springframework.ai.chat.messages.Message;
8+
import org.springframework.ai.chat.messages.SystemMessage;
9+
import org.springframework.beans.factory.annotation.Autowired;
10+
import org.springframework.stereotype.Service;
11+
12+
import java.util.*;
13+
import java.util.concurrent.ConcurrentHashMap;
14+
15+
@Service(value = "messageWindowChatMemory")
16+
public class MessageWindowChatMemory implements ChatMemoryStore{
17+
private final SysRoleService roleService;
18+
private final DatabaseChatMemory chatMemory;
19+
private final int maxMessages;
20+
private static final org.slf4j.Logger logger = org.slf4j.LoggerFactory.getLogger(MessageWindowChatMemory.class);
21+
// 缓存系统消息,避免频繁查询数据库
22+
private Map<String, String> systemMessageCache = new ConcurrentHashMap<>();
23+
24+
25+
@Autowired
26+
public MessageWindowChatMemory(SysRoleService roleService,DatabaseChatMemory chatMemory,int maxMessages){
27+
this.roleService = roleService;
28+
this.chatMemory = chatMemory;
29+
this.maxMessages = maxMessages;
30+
}
31+
32+
@Override
33+
public void addMessage(String deviceId, String sessionId, String sender, String content, Integer roleId, String messageType, String audioPath) {
34+
// TODO 修改接口为 使用UserMessage 、 AssistantMessage的概念
35+
chatMemory.addMessage(deviceId, sessionId, sender, content, roleId, messageType, audioPath);
36+
}
37+
38+
@Override
39+
public List<SysMessage> getMessages(String deviceId, String messageType, Integer limit) {
40+
return chatMemory.getMessages(deviceId, messageType, limit);
41+
}
42+
43+
@Override
44+
public void clearMessages(String deviceId) {
45+
chatMemory.clearMessages(deviceId);
46+
// 清除缓存
47+
systemMessageCache.keySet().removeIf(key -> key.startsWith(deviceId + ":"));
48+
}
49+
50+
// todo 在get时处理缩容。
51+
private List<Message> process(List<Message> memoryMessages, List<Message> newMessages) {
52+
List<Message> processedMessages = new ArrayList<>();
53+
54+
Set<Message> memoryMessagesSet = new HashSet<>(memoryMessages);
55+
boolean hasNewSystemMessage = newMessages.stream()
56+
.filter(SystemMessage.class::isInstance)
57+
.anyMatch(message -> !memoryMessagesSet.contains(message));
58+
59+
memoryMessages.stream()
60+
.filter(message -> !(hasNewSystemMessage && message instanceof SystemMessage))
61+
.forEach(processedMessages::add);
62+
63+
processedMessages.addAll(newMessages);
64+
65+
if (processedMessages.size() <= this.maxMessages) {
66+
return processedMessages;
67+
}
68+
69+
int messagesToRemove = processedMessages.size() - this.maxMessages;
70+
71+
List<Message> trimmedMessages = new ArrayList<>();
72+
int removed = 0;
73+
for (Message message : processedMessages) {
74+
if (message instanceof SystemMessage || removed >= messagesToRemove) {
75+
trimmedMessages.add(message);
76+
}
77+
else {
78+
removed++;
79+
}
80+
}
81+
82+
return trimmedMessages;
83+
}
84+
85+
@Override
86+
public String getSystemMessage(String deviceId, Integer roleId) {
87+
88+
if (roleId == null) {
89+
return "";
90+
}
91+
String cacheKey = deviceId + ":" + roleId;
92+
93+
// 先从缓存获取
94+
if (systemMessageCache.containsKey(cacheKey)) {
95+
return systemMessageCache.get(cacheKey);
96+
}
97+
98+
try {
99+
// 从数据库获取角色描述
100+
SysRole role = roleService.selectRoleById(roleId);
101+
if (role != null && role.getRoleDesc() != null) {
102+
String systemMessage = role.getRoleDesc();
103+
// 存入缓存
104+
systemMessageCache.put(cacheKey, systemMessage);
105+
return systemMessage;
106+
}
107+
} catch (Exception e) {
108+
logger.error("获取系统消息时出错: {}", e.getMessage(), e);
109+
}
110+
111+
return "";
112+
}
113+
114+
@Override
115+
public void setSystemMessage(String deviceId, Integer roleId, String systemMessage) {
116+
String cacheKey = deviceId + ":" + roleId;
117+
118+
// 更新缓存
119+
systemMessageCache.put(cacheKey, systemMessage);
120+
121+
}
122+
123+
/**
124+
* 添加用户消息
125+
*
126+
* @param message 用户消息
127+
*/
128+
public void addUserMessage(SysDevice device, String message, String messageType) {
129+
// 更新缓存
130+
this.addMessage(device.getDeviceId(), device.getSessionId(), "user", message,
131+
device.getRoleId(), messageType, null);
132+
}
133+
134+
/**
135+
* 添加AI消息
136+
*
137+
* @param message AI消息
138+
*/
139+
public void addAssistantMessage(SysDevice device, String message, String messageType) {
140+
this.addMessage(device.getDeviceId(), device.getSessionId(), "assistant", message,
141+
device.getRoleId(), messageType, null);
142+
}
143+
144+
}

0 commit comments

Comments
 (0)