Skip to content

Commit 8968ad1

Browse files
committed
reafactor:重构整体逻辑为role配置
1 parent eae7269 commit 8968ad1

File tree

16 files changed

+138
-185
lines changed

16 files changed

+138
-185
lines changed

src/main/java/com/xiaozhi/communication/common/MessageHandler.java

Lines changed: 18 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
import com.xiaozhi.dialogue.tts.factory.TtsServiceFactory;
1414
import com.xiaozhi.entity.SysConfig;
1515
import com.xiaozhi.entity.SysDevice;
16+
import com.xiaozhi.entity.SysRole;
1617
import com.xiaozhi.enums.ListenState;
1718
import com.xiaozhi.service.SysDeviceService;
1819
import com.xiaozhi.service.SysRoleService;
@@ -73,7 +74,7 @@ public class MessageHandler {
7374
private ToolsGlobalRegistry toolsGlobalRegistry;
7475

7576
@Resource
76-
private SysRoleService sysRoleService;
77+
private SysRoleService roleService;
7778

7879
// 用于存储设备ID和验证码生成状态的映射
7980
private final Map<String, Boolean> captchaGenerationInProgress = new ConcurrentHashMap<>();
@@ -85,14 +86,13 @@ public class MessageHandler {
8586
* @param deviceIdAuth
8687
*/
8788
public void afterConnection(ChatSession chatSession, String deviceIdAuth) {
88-
final String deviceId = deviceIdAuth;
89-
final String sessionId = chatSession.getSessionId();
89+
String deviceId = deviceIdAuth;
90+
String sessionId = chatSession.getSessionId();
9091
// 注册会话
9192
sessionManager.registerSession(sessionId, chatSession);
9293

9394
logger.info("开始查询设备信息 - DeviceId: {}", deviceId);
94-
final SysDevice device = Optional.ofNullable(deviceService.selectDeviceById(deviceId)).orElse(new SysDevice());
95-
95+
SysDevice device = Optional.ofNullable(deviceService.selectDeviceById(deviceId)).orElse(new SysDevice());
9696
device.setDeviceId(deviceId);
9797
device.setSessionId(sessionId);
9898
sessionManager.registerDevice(sessionId, device);
@@ -105,27 +105,29 @@ public void afterConnection(ChatSession chatSession, String deviceIdAuth) {
105105
//以上同步处理结束后,再启动虚拟线程进行设备初始化,确保chatSession中已设置的sysDevice信息
106106
Thread.startVirtualThread(() -> {
107107
try {
108-
if (device.getSttId() != null) {
109-
SysConfig sttConfig = configManager.getConfig(device.getSttId());
108+
SysRole role = roleService.selectRoleById(device.getRoleId());
109+
110+
if (role.getSttId() != null) {
111+
SysConfig sttConfig = configManager.getConfig(role.getSttId());
110112
if (sttConfig != null) {
111113
sttFactory.getSttService(sttConfig);// 提前初始化,加速后续使用
112114
}
113115
}
114-
if (device.getTtsId() != null) {
115-
SysConfig ttsConfig = configManager.getConfig(device.getTtsId());
116-
if (ttsConfig != null) {// 设备查询从join config表修改为只查设备表,所以这里可能会有空值
117-
ttsFactory.getTtsService(ttsConfig, device.getVoiceName());// 提前初始化,加速后续使用
116+
if (role.getTtsId() != null) {
117+
SysConfig ttsConfig = configManager.getConfig(role.getTtsId());
118+
if (ttsConfig != null) {
119+
ttsFactory.getTtsService(ttsConfig, role.getVoiceName());// 提前初始化,加速后续使用
118120
}
119121
}
120-
if (device.getModelId() != null) {
121-
chatModelFactory.takeChatModel(device);// 提前初始化,加速后续使用
122+
if (role.getModelId() != null) {
123+
chatModelFactory.takeChatModel(chatSession);// 提前初始化,加速后续使用
122124
chatService.initializeHistory(chatSession);
123125
// 注册全局函数
124126
toolsSessionHolder.registerGlobalFunctionTools(chatSession);
125127
}
126128

127129
// 更新设备状态
128-
deviceService.updateNoRefreshCache(new SysDevice()
130+
deviceService.update(new SysDevice()
129131
.setDeviceId(device.getDeviceId())
130132
.setState(SysDevice.DEVICE_STATE_ONLINE)
131133
.setLastLogin(new Date().toString()));
@@ -157,7 +159,7 @@ public void afterConnectionClosed(String sessionId) {
157159
if (device != null) {
158160
Thread.startVirtualThread(() -> {
159161
try {
160-
deviceService.updateNoRefreshCache(new SysDevice()
162+
deviceService.update(new SysDevice()
161163
.setDeviceId(device.getDeviceId())
162164
.setState(SysDevice.DEVICE_STATE_OFFLINE)
163165
.setLastLogin(new Date().toString()));
@@ -213,7 +215,7 @@ public void handleUnboundDevice(String sessionId, SysDevice device) {
213215
try {
214216
// 设备已注册但未配置模型
215217
if (device.getDeviceName() != null && device.getRoleId() == null) {
216-
String message = "设备未配置对话模型,请到配置页面完成配置后开始对话";
218+
String message = "设备未配置角色,请到角色配置页面完成配置后开始对话";
217219

218220
String audioFilePath = ttsService.getDefaultTtsService().textToSpeech(message);
219221
audioService.sendAudioMessage(chatSession, new DialogueService.Sentence(message, audioFilePath), true,

src/main/java/com/xiaozhi/communication/server/websocket/WebSocketHandler.java

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
import org.springframework.web.socket.CloseStatus;
1515
import org.springframework.web.socket.TextMessage;
1616
import org.springframework.web.socket.handler.AbstractWebSocketHandler;
17-
17+
import org.springframework.web.socket.WebSocketSession;
1818
import java.io.IOException;
1919
import java.net.URI;
2020
import java.util.HashMap;
@@ -38,7 +38,7 @@ public class WebSocketHandler extends AbstractWebSocketHandler {
3838
private DeviceMcpService deviceMcpService;
3939

4040
@Override
41-
public void afterConnectionEstablished(org.springframework.web.socket.WebSocketSession session) {
41+
public void afterConnectionEstablished(WebSocketSession session) {
4242
Map<String, String> headers = getHeadersFromSession(session);
4343
String deviceIdAuth = headers.get("device-id");
4444
String token = headers.get("Authorization");
@@ -68,17 +68,17 @@ public void afterConnectionEstablished(org.springframework.web.socket.WebSocketS
6868
// }
6969
// }else{
7070

71-
messageHandler.afterConnection(new WebSocketSession(session), deviceIdAuth);
71+
messageHandler.afterConnection(new com.xiaozhi.communication.server.websocket.WebSocketSession(session), deviceIdAuth);
7272
logger.info("WebSocket连接建立成功 - SessionId: {}, DeviceId: {}", session.getId(), deviceIdAuth);
7373

7474
}
7575

7676
@Override
77-
protected void handleTextMessage(org.springframework.web.socket.WebSocketSession session, TextMessage message) {
77+
protected void handleTextMessage(WebSocketSession session, TextMessage message) {
7878
String sessionId = session.getId();
79-
SysDevice device = sessionManager.getDeviceConfig(sessionId);
79+
String deviceId = sessionManager.getDeviceConfig(sessionId).getDeviceId();
80+
SysDevice device = deviceService.selectDeviceById(deviceId);
8081
String payload = message.getPayload();
81-
String deviceId = null;
8282
if (device == null) {
8383
deviceId = getHeadersFromSession(session).get("device-id");
8484
if (deviceId == null) {
@@ -108,7 +108,7 @@ protected void handleTextMessage(org.springframework.web.socket.WebSocketSession
108108
}
109109

110110
@Override
111-
protected void handleBinaryMessage(org.springframework.web.socket.WebSocketSession session, BinaryMessage message) {
111+
protected void handleBinaryMessage(WebSocketSession session, BinaryMessage message) {
112112
String sessionId = session.getId();
113113
SysDevice device = sessionManager.getDeviceConfig(sessionId);
114114
if (device == null) {
@@ -118,14 +118,14 @@ protected void handleBinaryMessage(org.springframework.web.socket.WebSocketSessi
118118
}
119119

120120
@Override
121-
public void afterConnectionClosed(org.springframework.web.socket.WebSocketSession session, CloseStatus status) {
121+
public void afterConnectionClosed(WebSocketSession session, CloseStatus status) {
122122
String sessionId = session.getId();
123123
messageHandler.afterConnectionClosed(sessionId);
124124
logger.info("WebSocket连接关闭 - SessionId: {}, 状态: {}", sessionId, status);
125125
}
126126

127127
@Override
128-
public void handleTransportError(org.springframework.web.socket.WebSocketSession session, Throwable exception) {
128+
public void handleTransportError(WebSocketSession session, Throwable exception) {
129129
String sessionId = session.getId();
130130
// 检查是否是客户端正常关闭连接导致的异常
131131
if (isClientCloseRequest(exception)) {
@@ -157,7 +157,7 @@ private boolean isClientCloseRequest(Throwable exception) {
157157
return false;
158158
}
159159

160-
private void handleHelloMessage(org.springframework.web.socket.WebSocketSession session, HelloMessage message) {
160+
private void handleHelloMessage(WebSocketSession session, HelloMessage message) {
161161
var sessionId = session.getId();
162162
logger.info("收到hello消息 - SessionId: {}, JsonNode: {}", sessionId, message);
163163

@@ -189,7 +189,7 @@ private void handleHelloMessage(org.springframework.web.socket.WebSocketSession
189189
}
190190
}
191191

192-
private Map<String, String> getHeadersFromSession(org.springframework.web.socket.WebSocketSession session) {
192+
private Map<String, String> getHeadersFromSession(WebSocketSession session) {
193193
// 尝试从请求头获取设备ID
194194
String[] deviceKeys = { "device-id", "mac_address", "uuid", "Authorization" };
195195

src/main/java/com/xiaozhi/controller/DeviceController.java

Lines changed: 19 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -68,30 +68,6 @@ public AjaxResult query(SysDevice device, HttpServletRequest request) {
6868
}
6969
}
7070

71-
/**
72-
* 设备信息更新
73-
*
74-
* @param device
75-
* @return
76-
*/
77-
@PostMapping("/update")
78-
@ResponseBody
79-
public AjaxResult update(SysDevice device) {
80-
try {
81-
device.setUserId(CmsUtils.getUserId());
82-
SysDevice persisteDevice = deviceService.update(device);
83-
if (persisteDevice != null) {
84-
deviceService.refreshSessionConfig(persisteDevice);
85-
return AjaxResult.success();
86-
} else {
87-
return AjaxResult.error();
88-
}
89-
} catch (Exception e) {
90-
logger.error(e.getMessage(), e);
91-
return AjaxResult.error();
92-
}
93-
}
94-
9571
/**
9672
* 添加设备
9773
*
@@ -133,6 +109,25 @@ public AjaxResult add(String code) {
133109
}
134110
}
135111

112+
/**
113+
* 设备信息更新
114+
*
115+
* @param device
116+
* @return
117+
*/
118+
@PostMapping("/update")
119+
@ResponseBody
120+
public AjaxResult update(SysDevice device) {
121+
try {
122+
device.setUserId(CmsUtils.getUserId());
123+
deviceService.update(device);
124+
return AjaxResult.success();
125+
} catch (Exception e) {
126+
logger.error(e.getMessage(), e);
127+
return AjaxResult.error();
128+
}
129+
}
130+
136131
/**
137132
* 删除设备
138133
*

src/main/java/com/xiaozhi/dialogue/llm/ChatService.java

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ public String chat(ChatSession session, String message, boolean useFunctionCall)
9090
SysDevice device = session.getSysDevice();
9191

9292
// 获取ChatModel
93-
ChatModel chatModel = chatModelFactory.takeChatModel(device);
93+
ChatModel chatModel = chatModelFactory.takeChatModel(session);
9494

9595
if (session.getChatMemory() == null) {// 如果记忆没初始化,则初始化一下
9696
initializeHistory(session);
@@ -155,7 +155,7 @@ public String chat(ChatSession session, String message, boolean useFunctionCall)
155155
public Flux<ChatResponse> chatStream(ChatSession session, SysDevice device, String message,
156156
boolean useFunctionCall) {
157157
// 获取ChatModel
158-
ChatModel chatModel = chatModelFactory.takeChatModel(device);
158+
ChatModel chatModel = chatModelFactory.takeChatModel(session);
159159

160160
ChatOptions chatOptions = ToolCallingChatOptions.builder()
161161
.toolCallbacks(useFunctionCall ? session.getToolCallbacks() : new ArrayList<>())
@@ -184,6 +184,7 @@ public void chatStreamBySentence(ChatSession session, String message, boolean us
184184
TriConsumer<String, Boolean, Boolean> sentenceHandler) {
185185
try {
186186
SysDevice device = session.getSysDevice();
187+
device.setSessionId(session.getSessionId());
187188
// 创建流式响应监听器
188189
StreamResponseListener streamListener = new TokenStreamResponseListener(session, message, sentenceHandler);
189190
final StringBuilder toolName = new StringBuilder(); // 当前句子的缓冲区

src/main/java/com/xiaozhi/dialogue/llm/factory/ChatModelFactory.java

Lines changed: 25 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,13 @@
11
package com.xiaozhi.dialogue.llm.factory;
22

3+
import com.xiaozhi.communication.common.ChatSession;
34
import com.xiaozhi.dialogue.llm.providers.CozeChatModel;
45
import com.xiaozhi.dialogue.llm.providers.DifyChatModel;
56
import com.xiaozhi.entity.SysConfig;
67
import com.xiaozhi.entity.SysDevice;
8+
import com.xiaozhi.entity.SysRole;
79
import com.xiaozhi.service.SysConfigService;
10+
import com.xiaozhi.service.SysRoleService;
811

912
import java.net.http.HttpClient;
1013
import java.time.Duration;
@@ -44,6 +47,8 @@ public class ChatModelFactory {
4447
@Autowired
4548
private SysConfigService configService;
4649
@Autowired
50+
private SysRoleService roleService;
51+
@Autowired
4752
private ToolCallingManager toolCallingManager;
4853
private final Logger logger = LoggerFactory.getLogger(ChatModelFactory.class);
4954

@@ -54,12 +59,14 @@ public class ChatModelFactory {
5459
* @param configId 配置ID,实际是模型ID。
5560
* @return
5661
*/
57-
public ChatModel takeChatModel(SysDevice device) {
58-
Integer modelId = device.getModelId();
62+
public ChatModel takeChatModel(ChatSession session) {
63+
SysDevice device = session.getSysDevice();
64+
SysRole role = roleService.selectRoleById(device.getRoleId());
65+
Integer modelId = role.getModelId();
5966
Assert.notNull(modelId, "配置ID不能为空");
6067
// 根据配置ID查询配置
6168
SysConfig config = configService.selectConfigById(modelId);
62-
return createChatModel(config, device);
69+
return createChatModel(config, role);
6370
}
6471

6572
/**
@@ -68,15 +75,15 @@ public ChatModel takeChatModel(SysDevice device) {
6875
* @param config
6976
* @return
7077
*/
71-
private ChatModel createChatModel(SysConfig config, SysDevice device) {
78+
private ChatModel createChatModel(SysConfig config, SysRole role) {
7279
String provider = config.getProvider().toLowerCase();
7380
String model = config.getConfigName();
7481
String endpoint = config.getApiUrl();
7582
String apiKey = config.getApiKey();
7683
String appId = config.getAppId();
7784
String apiSecret = config.getApiSecret();
78-
Double temperature = device.getTemperature();
79-
Double topP = device.getTopP();
85+
Double temperature = role.getTemperature();
86+
Double topP = role.getTopP();
8087
provider = provider.toLowerCase();
8188
switch (provider) {
8289
case "ollama":
@@ -96,14 +103,16 @@ private ChatModel createChatModel(SysConfig config, SysDevice device) {
96103
private ChatModel newOllamaChatModel(String endpoint, String appId, String apiKey, String apiSecret, String model, Double temperature, Double topP) {
97104
var ollamaApi = OllamaApi.builder().baseUrl(endpoint).build();
98105

106+
var ollamaAiChatOptions = OllamaOptions.builder()
107+
.model(model)
108+
.temperature(temperature)
109+
.topP(topP)
110+
.build();
111+
99112
var chatModel = OllamaChatModel.builder()
100113
.ollamaApi(ollamaApi)
101-
.defaultOptions(
102-
OllamaOptions.builder()
103-
.model(model)
104-
.temperature(temperature)
105-
.topP(topP)
106-
.build())
114+
.defaultOptions(ollamaAiChatOptions)
115+
.toolCallingManager(toolCallingManager)
107116
.build();
108117
logger.info("Using Ollama model: {}", model);
109118
return chatModel;
@@ -150,11 +159,13 @@ private ChatModel newOpenAiChatModel(String endpoint, String appId, String apiKe
150159
private ChatModel newZhipuChatModel(String endpoint, String appId, String apiKey, String apiSecret, String model, Double temperature, Double topP) {
151160
var zhiPuAiApi = new ZhiPuAiApi(endpoint, apiKey);
152161

153-
var chatModel = new ZhiPuAiChatModel(zhiPuAiApi, ZhiPuAiChatOptions.builder()
162+
var zhipuAiChatOptions = ZhiPuAiChatOptions.builder()
154163
.model(model)
155164
.temperature(temperature)
156165
.topP(topP)
157-
.build());
166+
.build();
167+
168+
var chatModel = new ZhiPuAiChatModel(zhiPuAiApi, zhipuAiChatOptions);
158169
logger.info("Using zhiPu model: {}", model);
159170
return chatModel;
160171
}

src/main/java/com/xiaozhi/dialogue/llm/tool/ToolsGlobalRegistry.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ public ToolCallback resolve(@NotNull String toolName) {
3838
*/
3939
public ToolCallback registerFunction(String name, ToolCallback functionCallTool) {
4040
ToolCallback result = allFunction.putIfAbsent(name, functionCallTool);
41-
logger.info("[{}] Function:{} registered into global successfully", TAG, name);
41+
logger.debug("[{}] Function:{} registered into global successfully", TAG, name);
4242
return result;
4343
}
4444

src/main/java/com/xiaozhi/dialogue/llm/tool/function/ChangeRoleFunction.java

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,8 +53,6 @@ public ToolCallback getFunctionCallTool(ChatSession chatSession) {
5353
if(role_id.isPresent()){
5454
sysDevice.setRoleId(role_id.get());//测试,固定角色
5555
sysDeviceService.update(sysDevice);
56-
sysDeviceService.refreshSessionConfig(sysDevice);
57-
chatSession.clearMemory();
5856
return "角色已切换至" + roleName;
5957
}else{
6058
return "角色切换失败, 没有对应角色哦";

src/main/java/com/xiaozhi/dialogue/service/AudioService.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -222,7 +222,7 @@ public CompletableFuture<Void> sendAudioMessage(
222222
@Override
223223
public void run() {
224224
try {
225-
if (!finalPlayingState.get() || frameIndex[0] >= opusFrames.size()) {
225+
if (!finalPlayingState.get() || frameIndex[0] >= opusFrames.size() || !session.isOpen()) {
226226
// 取消调度任务
227227
cancelScheduledTask(sessionId);
228228

0 commit comments

Comments
 (0)