Skip to content

Commit 32395b9

Browse files
committed
2 parents 0ab9480 + 65ca6dd commit 32395b9

File tree

11 files changed

+150
-96
lines changed

11 files changed

+150
-96
lines changed

src/main/java/com/xiaozhi/communication/common/MessageHandler.java

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,8 @@
22

33
import com.xiaozhi.communication.domain.*;
44
import com.xiaozhi.dialogue.llm.factory.ChatModelFactory;
5-
import com.xiaozhi.dialogue.llm.memory.ChatMemory;
65
import com.xiaozhi.dialogue.llm.memory.Conversation;
6+
import com.xiaozhi.dialogue.llm.memory.ConversationFactory;
77
import com.xiaozhi.dialogue.llm.tool.ToolsGlobalRegistry;
88
import com.xiaozhi.dialogue.llm.tool.ToolsSessionHolder;
99
import com.xiaozhi.dialogue.service.AudioService;
@@ -67,7 +67,7 @@ public class MessageHandler {
6767
private SttServiceFactory sttFactory;
6868

6969
@Autowired
70-
private ChatMemory chatMemory;
70+
private ConversationFactory conversationFactory;
7171

7272
@Resource
7373
private ChatModelFactory chatModelFactory;
@@ -124,7 +124,7 @@ public void afterConnection(ChatSession chatSession, String deviceIdAuth) {
124124
}
125125
if (role.getModelId() != null) {
126126
chatModelFactory.takeChatModel(chatSession);// 提前初始化,加速后续使用
127-
Conversation conversation = chatMemory.initConversation(device, role, sessionId);
127+
Conversation conversation = conversationFactory.initConversation(device, role, sessionId);
128128
chatSession.setConversation(conversation);
129129
// 注册全局函数
130130
toolsSessionHolder.registerGlobalFunctionTools(chatSession);
@@ -181,8 +181,11 @@ public void afterConnectionClosed(String sessionId) {
181181
audioService.cleanupSession(sessionId);
182182
// 清理对话
183183
dialogueService.cleanupSession(sessionId);
184-
// 清理ChatService缓存的对话历史。
185-
chatMemory.clearMessages(device.getDeviceId());
184+
// 清理Conversation缓存的对话历史。
185+
Conversation conversation = chatSession.getConversation();
186+
if (conversation != null) {
187+
conversation.clear();
188+
}
186189
}
187190

188191
/**

src/main/java/com/xiaozhi/dialogue/llm/factory/ChatModelFactory.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ public class ChatModelFactory {
5959
* 根据配置ID创建ChatModel,首次调用时缓存,缓存key为配置ID。
6060
*
6161
* @see SysConfigService#selectConfigById(Integer) 已经进行了Cacheable,所以此处没有必要缓存
62-
* @param configId 配置ID,实际是模型ID。
62+
* @param session 与网络链接绑定的聊天会话
6363
* @return
6464
*/
6565
public ChatModel takeChatModel(ChatSession session) {

src/main/java/com/xiaozhi/dialogue/llm/memory/ChatMemory.java

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,6 @@
11
package com.xiaozhi.dialogue.llm.memory;
22

3-
import com.xiaozhi.entity.SysDevice;
43
import com.xiaozhi.entity.SysMessage;
5-
import com.xiaozhi.entity.SysRole;
64

75
import java.util.List;
86

@@ -17,18 +15,20 @@
1715
public interface ChatMemory {
1816

1917
/**
20-
* 不同的ChatMemory实现类,可以有不同的处理策略,可以初始化不同的Conversation子类。
21-
*
22-
* @param device 设备
23-
* @param role 角色
24-
* @param sessionId 会话ID
25-
* @return 会话
18+
* 添加消息
19+
* TODO 参数太多,后续考虑如何简化一些
2620
*/
27-
Conversation initConversation(SysDevice device, SysRole role, String sessionId);
28-
29-
// TODO 考虑将参数以Message对象传递,而不是多个参数。在再下一层Service层转换sparing ai的Message为Mapper需要的SysMessage对象
3021
void addMessage(String deviceId, String sessionId, String sender, String content, Integer roleId, String messageType, Long timeMillis);
3122

23+
/**
24+
* 获取历史对话消息列表
25+
* TODO messageType参数,后续考虑是否需要。另外可重构为一个枚举类
26+
*
27+
* @param deviceId 设备ID
28+
* @param messageType 消息类型
29+
* @param limit 限制数量
30+
* @return 消息列表
31+
*/
3232
List<SysMessage> getMessages(String deviceId, String messageType, Integer limit);
3333

3434
/**

src/main/java/com/xiaozhi/dialogue/llm/memory/Conversation.java

Lines changed: 54 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,19 @@
11
package com.xiaozhi.dialogue.llm.memory;
22

33
import com.xiaozhi.entity.SysDevice;
4+
import com.xiaozhi.entity.SysMessage;
45
import com.xiaozhi.entity.SysRole;
56
import org.springframework.ai.chat.messages.AssistantMessage;
67
import org.springframework.ai.chat.messages.Message;
8+
import org.springframework.ai.chat.messages.MessageType;
79
import org.springframework.ai.chat.messages.UserMessage;
10+
import org.springframework.util.Assert;
811

912
import java.util.ArrayList;
13+
import java.util.Collections;
1014
import java.util.List;
15+
import java.util.Map;
16+
import java.util.stream.Collectors;
1117

1218
/**
1319
* Conversation 是一个 对应于 sys_message 表的,但高于 sys_message 的一个抽象实体。
@@ -18,18 +24,23 @@
1824
* 只有sessionID是真正挂在Conversation的属性。
1925
*
2026
*/
21-
public abstract class Conversation {
27+
public class Conversation {
2228
private final SysDevice device;
2329
private final SysRole role;
2430
private final String sessionId;
2531

26-
private List<Message> messages;
32+
protected List<Message> messages = new ArrayList<>();
2733

28-
public Conversation(SysDevice device, SysRole role, String sessionId, List<Message> messages) {
34+
public Conversation(SysDevice device, SysRole role, String sessionId) {
35+
// final 属性的规范要求
36+
Assert.notNull(device, "device must not be null");
37+
Assert.notNull(role, "role must not be null");
38+
Assert.notNull(device.getDeviceId(), "deviceId must not be null");
39+
Assert.notNull(role.getRoleId(), "roleId must not be null");
40+
Assert.notNull(sessionId, "sessionId must not be null");
2941
this.device = device;
3042
this.role = role;
3143
this.sessionId = sessionId;
32-
this.messages = new ArrayList<>(messages);
3344
}
3445

3546
public SysDevice device() {
@@ -47,15 +58,51 @@ public List<Message> messages() {
4758
return messages;
4859
}
4960

50-
abstract public void clear();
61+
public void clear(){
62+
messages.clear();
63+
}
5164

52-
abstract public void addMessage(UserMessage userMessage, Long userTimeMillis,AssistantMessage assistantMessage, Long assistantTimeMillis);
65+
public void addMessage(UserMessage userMessage, Long userTimeMillis,AssistantMessage assistantMessage, Long assistantTimeMillis){
66+
messages.add(userMessage);
67+
messages.add(assistantMessage);
68+
}
5369

5470
/**
5571
* 获取适用于放入prompt提示词的多轮消息列表。
5672
* userMessage 不会因调用此方法而入库(或进入记忆)
5773
* @param userMessage 必须且不为空。
74+
* @return 新的消息列表对象,避免污染原有的列表。
75+
*/
76+
public List<Message> prompt(UserMessage userMessage){
77+
List<Message> newMessages = new ArrayList<>();
78+
newMessages.addAll(this.messages);
79+
newMessages.add(userMessage);
80+
return newMessages;
81+
}
82+
83+
/**
84+
* 将数据库记录的SysMessag转换为spring-ai的Message。
85+
*
86+
* @param messages
5887
* @return
5988
*/
60-
abstract public List<Message> prompt(UserMessage userMessage);
89+
public static List<Message> convert(List<SysMessage> messages) {
90+
if (messages == null || messages.isEmpty()) {
91+
return Collections.emptyList();
92+
}
93+
return messages.stream()
94+
.filter(message -> MessageType.ASSISTANT.getValue().equals(message.getSender())
95+
|| MessageType.USER.getValue().equals(message.getSender()))
96+
.map(message -> {
97+
String role = message.getSender();
98+
// 一般消息("messageType", "NORMAL");//默认为普通消息
99+
Map<String, Object> metadata = Map.of("messageId", message.getMessageId(), "messageType",
100+
message.getMessageType());
101+
return switch (role) {
102+
case "assistant" -> new AssistantMessage(message.getMessage(), metadata);
103+
case "user" -> UserMessage.builder().text(message.getMessage()).metadata(metadata).build();
104+
default -> throw new IllegalArgumentException("Invalid role: " + role);
105+
};
106+
}).collect(Collectors.toList());
107+
}
61108
}
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
package com.xiaozhi.dialogue.llm.memory;
2+
3+
import com.xiaozhi.entity.SysDevice;
4+
import com.xiaozhi.entity.SysRole;
5+
6+
public interface ConversationFactory {
7+
/**
8+
* 不同的ChatMemory实现类,可以有不同的处理策略,可以初始化不同的Conversation子类。
9+
*
10+
* @param device 设备
11+
* @param role 角色
12+
* @param sessionId 会话ID
13+
* @return 会话
14+
*/
15+
Conversation initConversation(SysDevice device, SysRole role, String sessionId);
16+
}

src/main/java/com/xiaozhi/dialogue/llm/memory/DatabaseChatMemory.java

Lines changed: 6 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,7 @@
22

33
import com.xiaozhi.common.web.PageFilter;
44
import com.xiaozhi.entity.Base;
5-
import com.xiaozhi.entity.SysDevice;
65
import com.xiaozhi.entity.SysMessage;
7-
import com.xiaozhi.entity.SysRole;
86
import com.xiaozhi.service.SysMessageService;
97

108
import org.slf4j.Logger;
@@ -19,28 +17,20 @@
1917
import java.util.Date;
2018
import java.util.List;
2119

22-
import static com.xiaozhi.dialogue.llm.memory.MessageWindowConversation.DEFAULT_HISTORY_LIMIT;
23-
2420
/**
2521
* 基于数据库的聊天记忆实现
26-
* 全局单例类,负责Conversatin的初始化、保存、清理。
22+
* 全局单例类,负责Conversatin里消息的获取、保存、清理。
23+
* 后续考虑:DatabaseChatMemory 是对 SysMessageService 的一层薄封装,未来或者有可能考虑合并这两者。
2724
*/
2825
@Service
2926
public class DatabaseChatMemory implements ChatMemory {
3027
private static final Logger logger = LoggerFactory.getLogger(DatabaseChatMemory.class);
3128

32-
@Autowired
33-
private SysMessageService messageService;
29+
private final SysMessageService messageService;
3430

35-
@Override
36-
public Conversation initConversation(SysDevice device, SysRole role, String sessionId) {
37-
Conversation conversation = MessageWindowConversation.builder().chatMemory(this)
38-
.maxMessages(DEFAULT_HISTORY_LIMIT)
39-
.role(role)
40-
.device(device)
41-
.sessionId(sessionId)
42-
.build();
43-
return conversation;
31+
@Autowired
32+
public DatabaseChatMemory(SysMessageService messageService) {
33+
this.messageService = messageService;
4434
}
4535

4636
@Override

src/main/java/com/xiaozhi/dialogue/llm/memory/MessageWindowConversation.java

Lines changed: 11 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -4,38 +4,37 @@
44
import com.xiaozhi.entity.SysMessage;
55
import com.xiaozhi.entity.SysRole;
66
import org.springframework.ai.chat.messages.*;
7-
import org.springframework.beans.factory.annotation.Autowired;
87

9-
import org.springframework.util.Assert;
108
import org.springframework.util.StringUtils;
119

1210
import java.util.*;
1311

14-
import java.util.stream.Collectors;
15-
1612
/**
1713
* 限定消息条数(消息窗口)的Conversation实现。根据不同的策略,可实现聊天会话的持久化、加载、清除等功能。
1814
*/
1915
public class MessageWindowConversation extends Conversation {
2016
// 历史记录默认限制数量
2117
public static final int DEFAULT_HISTORY_LIMIT = 10;
22-
private final DatabaseChatMemory chatMemory;
18+
private final ChatMemory chatMemory;
2319
private final int maxMessages;
2420
private static final org.slf4j.Logger logger = org.slf4j.LoggerFactory.getLogger(MessageWindowConversation.class);
2521

2622

27-
public MessageWindowConversation(DatabaseChatMemory chatMemory, SysDevice device, SysRole role, int maxMessages, List<Message> messages){
28-
super(device, role, device.getSessionId(), messages);
29-
this.chatMemory = chatMemory;
23+
public MessageWindowConversation(SysDevice device, SysRole role, String sessionId, int maxMessages, ChatMemory chatMemory){
24+
super(device, role, sessionId);
3025
this.maxMessages = maxMessages;
26+
this.chatMemory = chatMemory;
27+
logger.info("加载设备{}的普通消息(SysMessage.MESSAGE_TYPE_NORMAL)作为对话历史",device.getDeviceId());
28+
List<SysMessage> history = chatMemory.getMessages(device.getDeviceId(), SysMessage.MESSAGE_TYPE_NORMAL, maxMessages);
29+
super.messages.addAll(convert(history)) ;
3130
}
3231

3332
public static class Builder {
3433
private SysDevice device;
3534
private SysRole role;
3635
private String sessionId;
3736
private int maxMessages;
38-
private DatabaseChatMemory chatMemory;
37+
private ChatMemory chatMemory;
3938

4039
public Builder device(SysDevice device) {
4140
this.device = device;
@@ -51,7 +50,7 @@ public Builder sessionId(String sessionId) {
5150
return this;
5251
}
5352

54-
public Builder chatMemory(DatabaseChatMemory chatMemory) {
53+
public Builder chatMemory(ChatMemory chatMemory) {
5554
this.chatMemory = chatMemory;
5655
return this;
5756
}
@@ -62,17 +61,7 @@ public Builder maxMessages(int maxMessages) {
6261
}
6362

6463
public MessageWindowConversation build(){
65-
Assert.notNull(device, "device must not be null");
66-
Assert.notNull(role, "role must not be null");
67-
String deviceId = device.getDeviceId();
68-
Assert.notNull(deviceId, "deviceId must not be null");
69-
Assert.notNull(role.getRoleId(), "roleId must not be null");
70-
Assert.notNull(sessionId, "sessionId must not be null");
71-
logger.info("获取设备{}的历史消息",deviceId);
72-
List<SysMessage> history = chatMemory.getMessages(deviceId, SysMessage.MESSAGE_TYPE_NORMAL, maxMessages);
73-
List<Message> messages =convert(history);
74-
75-
return new MessageWindowConversation(chatMemory,device,role,maxMessages,messages);
64+
return new MessageWindowConversation(device,role,sessionId,maxMessages,chatMemory);
7665
}
7766
}
7867

@@ -132,6 +121,7 @@ public void addMessage(UserMessage userMessage, Long userTimeMillis, AssistantM
132121
}
133122
}
134123

124+
@Override
135125
public List<Message> prompt(UserMessage userMessage) {
136126
String roleDesc = role().getRoleDesc();
137127
SystemMessage systemMessage = new SystemMessage(StringUtils.hasText(roleDesc)?roleDesc:"");
@@ -149,31 +139,4 @@ public List<Message> prompt(UserMessage userMessage) {
149139
return messages;
150140
}
151141

152-
/**
153-
* 将数据库记录的SysMessag转换为spring-ai的Message。
154-
* 加载的历史都是普通消息(SysMessage.MESSAGE_TYPE_NORMAL)
155-
*
156-
* @param messages
157-
* @return
158-
*/
159-
public static List<Message> convert(List<SysMessage> messages) {
160-
if (messages == null || messages.isEmpty()) {
161-
return Collections.emptyList();
162-
}
163-
return messages.stream()
164-
.filter(message -> MessageType.ASSISTANT.getValue().equals(message.getSender())
165-
|| MessageType.USER.getValue().equals(message.getSender()))
166-
.map(message -> {
167-
String role = message.getSender();
168-
// 一般消息("messageType", "NORMAL");//默认为普通消息
169-
Map<String, Object> metadata = Map.of("messageId", message.getMessageId(), "messageType",
170-
message.getMessageType());
171-
return switch (role) {
172-
case "assistant" -> new AssistantMessage(message.getMessage(), metadata);
173-
case "user" -> UserMessage.builder().text(message.getMessage()).metadata(metadata).build();
174-
default -> throw new IllegalArgumentException("Invalid role: " + role);
175-
};
176-
}).collect(Collectors.toList());
177-
}
178-
179142
}

0 commit comments

Comments
 (0)