Skip to content

Commit 820fb57

Browse files
authored
Merge pull request #104 from vritser/main
优化 STT 相关
2 parents 8919c0a + 11d288d commit 820fb57

File tree

2 files changed

+67
-257
lines changed

2 files changed

+67
-257
lines changed

src/main/java/com/xiaozhi/dialogue/stt/factory/SttServiceFactory.java

Lines changed: 35 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,13 @@
44
import com.xiaozhi.dialogue.stt.providers.*;
55
import com.xiaozhi.entity.SysConfig;
66

7+
import jakarta.annotation.Nonnull;
78
import org.slf4j.Logger;
89
import org.slf4j.LoggerFactory;
910
import org.springframework.stereotype.Component;
1011

1112
import jakarta.annotation.PostConstruct;
13+
1214
import java.util.Map;
1315
import java.util.concurrent.ConcurrentHashMap;
1416

@@ -17,7 +19,7 @@ public class SttServiceFactory {
1719

1820
private static final Logger logger = LoggerFactory.getLogger(SttServiceFactory.class);
1921

20-
// 缓存已初始化的服务:对于API服务,键为"provider:configId"格式;对于本地服务,键为provider名称
22+
// 缓存已初始化的服务:key format: "provider:configId"
2123
private final Map<String, SttService> serviceCache = new ConcurrentHashMap<>();
2224

2325
// 默认服务提供商名称
@@ -46,128 +48,81 @@ public void initializeDefaultSttService() {
4648
/**
4749
* 初始化Vosk服务
4850
*/
49-
private synchronized void initializeVosk() {
51+
private synchronized SttService initializeVosk() {
5052
if (serviceCache.containsKey(DEFAULT_PROVIDER)) {
51-
return;
53+
return serviceCache.get(DEFAULT_PROVIDER);
5254
}
5355

5456
try {
55-
VoskSttService voskService = new VoskSttService();
57+
var voskService = new VoskSttService();
5658
voskService.initialize();
5759
serviceCache.put(DEFAULT_PROVIDER, voskService);
5860
voskInitialized = true;
5961
logger.info("Vosk STT服务初始化成功");
62+
return voskService;
6063
} catch (Exception e) {
6164
logger.error("Vosk STT服务初始化失败", e);
6265
voskInitialized = false;
6366
}
67+
return null;
6468
}
6569

6670
/**
6771
* 获取默认STT服务
68-
* 如果Vosk可用则返回Vosk,否则返回备选服务
6972
*/
7073
public SttService getDefaultSttService() {
71-
// 如果Vosk已初始化成功,直接返回
72-
if (voskInitialized && serviceCache.containsKey(DEFAULT_PROVIDER)) {
73-
return serviceCache.get(DEFAULT_PROVIDER);
74-
}
75-
76-
// 否则返回备选服务
77-
if (fallbackProvider != null && serviceCache.containsKey(fallbackProvider)) {
78-
return serviceCache.get(fallbackProvider);
79-
}
80-
81-
// 如果没有备选服务,尝试创建一个API类型的服务作为备选
82-
if (serviceCache.isEmpty()) {
83-
logger.warn("没有可用的STT服务,将尝试创建默认API服务");
84-
try {
85-
return null;
86-
} catch (Exception e) {
87-
logger.error("创建默认API服务失败", e);
88-
return null;
89-
}
90-
}
91-
92-
return null;
74+
return getSttService(null);
9375
}
9476

9577
/**
9678
* 根据配置获取STT服务
9779
*/
9880
public SttService getSttService(SysConfig config) {
9981
if (config == null) {
100-
return getDefaultSttService();
101-
}
102-
103-
String provider = config.getProvider();
104-
105-
// 如果是Vosk,直接使用全局共享的实例
106-
if (DEFAULT_PROVIDER.equals(provider)) {
107-
// 如果Vosk还未初始化,尝试初始化
108-
if (!voskInitialized && !serviceCache.containsKey(DEFAULT_PROVIDER)) {
109-
initializeVosk();
110-
}
111-
112-
// Vosk初始化失败的情况
113-
if (!voskInitialized) {
114-
return null;
115-
}
116-
return serviceCache.get(DEFAULT_PROVIDER);
82+
config = new SysConfig().setProvider(DEFAULT_PROVIDER).setConfigId(-1);
11783
}
11884

11985
// 对于API服务,使用"provider:configId"作为缓存键,确保每个配置使用独立的服务实例
120-
Integer configId = config.getConfigId();
121-
String cacheKey = provider + ":" + (configId != null ? configId : "default");
86+
var cacheKey = config.getProvider() + ":" + config.getConfigId();
12287

12388
// 检查是否已有该配置的服务实例
12489
if (serviceCache.containsKey(cacheKey)) {
12590
return serviceCache.get(cacheKey);
12691
}
12792

12893
// 创建新的API服务实例
129-
try {
130-
SttService service = createApiService(config);
131-
if (service != null) {
132-
serviceCache.put(cacheKey, service);
94+
var service = createApiService(config);
95+
serviceCache.put(cacheKey, service);
13396

134-
// 如果没有备选默认服务,将此服务设为备选
135-
if (fallbackProvider == null) {
136-
fallbackProvider = cacheKey;
137-
}
138-
return service;
139-
}
140-
} catch (Exception e) {
141-
logger.error("创建{}服务失败, configId={}", provider, configId, e);
97+
// 如果没有备选默认服务,将此服务设为备选
98+
if (fallbackProvider == null) {
99+
fallbackProvider = cacheKey;
142100
}
143101

144-
return null;
102+
return service;
145103
}
146104

147105
/**
148106
* 根据配置创建API类型的STT服务
149107
*/
150-
private SttService createApiService(SysConfig config) {
151-
if (config == null) {
152-
return null;
153-
}
154-
155-
String provider = config.getProvider();
156-
157-
// 根据提供商类型创建对应的服务实例
158-
if ("tencent".equals(provider)) {
159-
return new TencentSttService(config);
160-
} else if ("aliyun".equals(provider)) {
161-
return new AliyunSttService(config);
162-
} else if ("funasr".equals(provider)) {
163-
return new FunASRSttService(config);
164-
} else if ("xfyun".equals(provider)) {
165-
return new XfyunSttService(config);
166-
}
167-
// 可以添加其他服务提供商的支持
168-
169-
logger.warn("不支持的STT服务提供商: {}", provider);
170-
return null;
108+
private SttService createApiService(@Nonnull SysConfig config) {
109+
return switch (config.getProvider()) {
110+
case "tencent" -> new TencentSttService(config);
111+
case "aliyun" -> new AliyunSttService(config);
112+
case "funasr" -> new FunASRSttService(config);
113+
case "xfyun" -> new XfyunSttService(config);
114+
default -> {
115+
var service = initializeVosk();
116+
if (service == null) {
117+
// If vosk create failed, return fallback stt service
118+
if (fallbackProvider != null && serviceCache.containsKey(fallbackProvider)) {
119+
yield serviceCache.get(fallbackProvider);
120+
}
121+
throw new RuntimeException("Create vosk service failed");
122+
}
123+
yield service;
124+
}
125+
};
171126
}
172127

173128
public void removeCache(SysConfig config) {

0 commit comments

Comments
 (0)