Skip to content

Commit 9cd5485

Browse files
committed
refactor:重构coze使用OAuth认证
1 parent 87891a6 commit 9cd5485

File tree

9 files changed

+435
-158
lines changed

9 files changed

+435
-158
lines changed

src/main/java/com/xiaozhi/dialogue/llm/factory/ChatModelFactory.java

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import com.xiaozhi.communication.common.ChatSession;
44
import com.xiaozhi.dialogue.llm.providers.CozeChatModel;
55
import com.xiaozhi.dialogue.llm.providers.DifyChatModel;
6+
import com.xiaozhi.dialogue.token.factory.TokenServiceFactory;
67
import com.xiaozhi.entity.SysConfig;
78
import com.xiaozhi.entity.SysDevice;
89
import com.xiaozhi.entity.SysRole;
@@ -50,6 +51,8 @@ public class ChatModelFactory {
5051
private SysRoleService roleService;
5152
@Autowired
5253
private ToolCallingManager toolCallingManager;
54+
@Autowired
55+
private TokenServiceFactory tokenService;
5356
private final Logger logger = LoggerFactory.getLogger(ChatModelFactory.class);
5457

5558
/**
@@ -85,15 +88,21 @@ private ChatModel createChatModel(SysConfig config, SysRole role) {
8588
Double temperature = role.getTemperature();
8689
Double topP = role.getTopP();
8790
provider = provider.toLowerCase();
91+
// Coze和Dify 拥有全局唯一配置,所以需要查询唯一配置信息来作为模型的 Token 获取
92+
SysConfig agentConfig = new SysConfig().setConfigType("agent").setUserId(config.getUserId());
93+
SysConfig queryConfig;
8894
switch (provider) {
8995
case "ollama":
9096
return newOllamaChatModel(endpoint, appId, apiKey, apiSecret, model, temperature, topP);
9197
case "zhipu":
9298
return newZhipuChatModel(endpoint, appId, apiKey, apiSecret, model, temperature, topP);
9399
case "dify":
94-
return new DifyChatModel(endpoint, appId, apiKey, apiSecret, model);
100+
queryConfig = configService.query(agentConfig.setProvider("dify"), null).get(0);
101+
return new DifyChatModel(endpoint, queryConfig.getApiKey());
95102
case "coze":
96-
return new CozeChatModel(endpoint, appId, apiKey, apiSecret, model);
103+
queryConfig = configService.query(agentConfig.setProvider("coze"), null).get(0);
104+
String token = tokenService.getTokenService(queryConfig).getToken();
105+
return new CozeChatModel(token, model);
97106
// 默认为 openai 协议
98107
default:
99108
return newOpenAiChatModel(endpoint, appId, apiKey, apiSecret, model, temperature, topP);

src/main/java/com/xiaozhi/dialogue/llm/providers/CozeChatModel.java

Lines changed: 4 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,9 @@
77
import com.coze.openapi.client.connversations.message.model.MessageType;
88
import com.coze.openapi.service.auth.TokenAuth;
99
import com.coze.openapi.service.service.CozeAPI;
10+
1011
import io.reactivex.Flowable;
12+
1113
import org.slf4j.Logger;
1214
import org.slf4j.LoggerFactory;
1315
import org.springframework.ai.chat.messages.AssistantMessage;
@@ -32,29 +34,16 @@ public class CozeChatModel implements ChatModel {
3234
private final CozeAPI coze;
3335
private final String botId;
3436

35-
private final String endpoint;
36-
private final String apiKey;
37-
private final String model;
38-
private final String appId;
39-
private final String apiSecret;
4037
private final Logger logger = LoggerFactory.getLogger(getClass());
4138
public static final String PROVIDER_NAME = "coze";
4239

40+
4341
/**
4442
* 构造函数
4543
*
46-
* @param endpoint API端点
47-
* @param appId 应用ID (在Coze中对应botId)
48-
* @param apiKey API密钥 (在Coze中不使用)
49-
* @param apiSecret API密钥 (在Coze中对应access_token)
5044
* @param model 模型名称 (在Coze中不使用)
5145
*/
52-
public CozeChatModel(String endpoint, String appId, String apiKey, String apiSecret, String model) {
53-
this.endpoint = endpoint;
54-
this.appId = appId;
55-
this.apiSecret = apiSecret;
56-
this.apiKey = apiKey;
57-
this.model = model;
46+
public CozeChatModel(String apiSecret, String model) {
5847

5948
// 使用apiSecret作为access_token
6049
this.authCli = new TokenAuth(apiSecret);

src/main/java/com/xiaozhi/dialogue/llm/providers/DifyChatModel.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ public class DifyChatModel implements ChatModel {
3333
* @param apiSecret
3434
* @param model 模型名称
3535
*/
36-
public DifyChatModel(String endpoint, String appId, String apiKey, String apiSecret, String model) {
36+
public DifyChatModel(String endpoint, String apiKey) {
3737
chatClient = DifyClientFactory.createChatClient(endpoint, apiKey);
3838
}
3939

src/main/java/com/xiaozhi/dialogue/token/factory/TokenServiceFactory.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
import com.xiaozhi.dialogue.token.TokenService;
88
import com.xiaozhi.dialogue.token.providers.AliyunTokenService;
9+
import com.xiaozhi.dialogue.token.providers.CozeTokenService;
910
import com.xiaozhi.entity.SysConfig;
1011

1112
import jakarta.annotation.PostConstruct;
@@ -86,7 +87,7 @@ public TokenService getTokenService(SysConfig config) {
8687
private TokenService createTokenService(SysConfig config) {
8788
return switch (config.getProvider()) {
8889
case "aliyun" -> new AliyunTokenService(config);
89-
default -> new AliyunTokenService(config);
90+
default -> new CozeTokenService(config);
9091
};
9192
}
9293

src/main/java/com/xiaozhi/dialogue/token/providers/AliyunTokenService.java

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -171,10 +171,4 @@ public boolean needsCacheCleanup() {
171171
return tokenCache != null && tokenCache.needsCacheCleanup();
172172
}
173173

174-
/**
175-
* 获取配置ID
176-
*/
177-
public Integer getConfigId() {
178-
return configId;
179-
}
180174
}
Lines changed: 246 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,246 @@
1+
package com.xiaozhi.dialogue.token.providers;
2+
3+
import org.slf4j.Logger;
4+
import org.slf4j.LoggerFactory;
5+
import org.springframework.http.*;
6+
import org.springframework.web.client.RestTemplate;
7+
8+
import com.alibaba.fastjson.JSON;
9+
import com.alibaba.fastjson.JSONObject;
10+
import com.xiaozhi.dialogue.token.TokenService;
11+
import com.xiaozhi.dialogue.token.entity.TokenCache;
12+
import com.xiaozhi.entity.SysConfig;
13+
14+
import io.jsonwebtoken.Jwts;
15+
import io.jsonwebtoken.SignatureAlgorithm;
16+
17+
import java.security.KeyFactory;
18+
import java.security.PrivateKey;
19+
import java.security.spec.PKCS8EncodedKeySpec;
20+
import java.time.LocalDateTime;
21+
import java.util.Base64;
22+
import java.util.HashMap;
23+
import java.util.Map;
24+
import java.util.UUID;
25+
import java.util.concurrent.locks.ReentrantLock;
26+
27+
public class CozeTokenService implements TokenService {
28+
29+
private static final Logger logger = LoggerFactory.getLogger(CozeTokenService.class);
30+
31+
private static final String PROVIDER_NAME = "coze";
32+
33+
private final String oauthAppId; // OAuth应用ID
34+
private final String publicKey; // 公钥
35+
private final String privateKey; // 私钥
36+
private final Integer configId;
37+
38+
// Token缓存
39+
private volatile TokenCache tokenCache;
40+
// 防止并发刷新的锁
41+
private final ReentrantLock refreshLock = new ReentrantLock();
42+
43+
// Coze API配置
44+
private static final String COZE_API_ENDPOINT = "api.coze.cn";
45+
private static final String TOKEN_URL = "https://api.coze.cn/api/permission/oauth2/token";
46+
private static final String ALGORITHM = "RS256";
47+
private static final String TOKEN_TYPE = "JWT";
48+
private static final int DEFAULT_DURATION_SECONDS = 86399; // 24小时 - 1秒
49+
50+
private final RestTemplate restTemplate;
51+
52+
public CozeTokenService(SysConfig config) {
53+
this.oauthAppId = config.getAppId(); // OAuth应用ID
54+
this.publicKey = config.getAk(); // 公钥
55+
this.privateKey = config.getSk(); // 私钥
56+
this.configId = config.getConfigId();
57+
58+
this.restTemplate = new RestTemplate();
59+
}
60+
61+
@Override
62+
public String getProviderName() {
63+
return PROVIDER_NAME;
64+
}
65+
66+
@Override
67+
public String getToken() {
68+
// 检查缓存是否存在且有效
69+
if (tokenCache != null) {
70+
// 更新最后使用时间
71+
tokenCache.updateLastUsedTime();
72+
73+
// 如果token已过期,清除缓存
74+
if (tokenCache.isExpired()) {
75+
clearTokenCache();
76+
}
77+
// 如果需要刷新(剩余1小时),异步刷新
78+
else if (tokenCache.needsRefresh()) {
79+
refreshTokenAsync();
80+
return tokenCache.getToken(); // 返回当前还有效的token
81+
}
82+
// 如果token仍然有效,直接返回
83+
else if (isTokenValid()) {
84+
return tokenCache.getToken();
85+
}
86+
}
87+
88+
// 缓存无效或不存在,获取新token
89+
return refreshToken();
90+
}
91+
92+
@Override
93+
public String refreshToken() {
94+
refreshLock.lock();
95+
try {
96+
// 双重检查,防止重复刷新
97+
if (tokenCache != null && isTokenValid() && !tokenCache.needsRefresh()) {
98+
return tokenCache.getToken();
99+
}
100+
101+
// 1. 生成JWT
102+
String jwt = generateJWT();
103+
104+
// 2. 使用JWT获取访问令牌
105+
String accessToken = requestAccessToken(jwt);
106+
107+
// 3. 计算过期时间(默认24小时)
108+
LocalDateTime expireTime = LocalDateTime.now().plusSeconds(DEFAULT_DURATION_SECONDS);
109+
110+
// 4. 更新缓存
111+
tokenCache = new TokenCache(accessToken, expireTime);
112+
113+
return accessToken;
114+
115+
} catch (Exception e) {
116+
throw new RuntimeException("刷新Coze Token失败: " + e.getMessage(), e);
117+
} finally {
118+
refreshLock.unlock();
119+
}
120+
}
121+
122+
/**
123+
* 生成JWT
124+
*/
125+
private String generateJWT() throws Exception {
126+
long currentTime = System.currentTimeMillis() / 1000;
127+
128+
// 构建Header
129+
Map<String, Object> headers = new HashMap<>();
130+
headers.put("alg", ALGORITHM);
131+
headers.put("typ", TOKEN_TYPE);
132+
headers.put("kid", publicKey);
133+
134+
// 构建Payload
135+
Map<String, Object> claims = new HashMap<>();
136+
claims.put("iss", oauthAppId); // OAuth应用ID
137+
claims.put("aud", COZE_API_ENDPOINT); // Coze API Endpoint
138+
claims.put("iat", currentTime); // 开始生效时间
139+
claims.put("exp", currentTime + 600); // JWT过期时间(10分钟后)
140+
claims.put("jti", UUID.randomUUID().toString()); // 随机字符串,防止重放攻击
141+
142+
// 可选参数
143+
// claims.put("session_name", "user_" + configId);
144+
145+
// 解析私钥
146+
PrivateKey key = parsePrivateKey(privateKey);
147+
148+
// 生成JWT
149+
String jwt = Jwts.builder()
150+
.setHeader(headers)
151+
.setClaims(claims)
152+
.signWith(key, SignatureAlgorithm.RS256)
153+
.compact();
154+
155+
return jwt;
156+
}
157+
158+
/**
159+
* 解析私钥
160+
*/
161+
private PrivateKey parsePrivateKey(String privateKeyStr) throws Exception {
162+
// 清理私钥字符串
163+
String cleanKey = privateKeyStr
164+
.replace("-----BEGIN PRIVATE KEY-----", "")
165+
.replace("-----END PRIVATE KEY-----", "")
166+
.replaceAll("\\s", "");
167+
168+
// 解码Base64
169+
byte[] keyBytes = Base64.getDecoder().decode(cleanKey);
170+
171+
// 创建私钥
172+
PKCS8EncodedKeySpec spec = new PKCS8EncodedKeySpec(keyBytes);
173+
KeyFactory keyFactory = KeyFactory.getInstance("RSA");
174+
return keyFactory.generatePrivate(spec);
175+
}
176+
177+
/**
178+
* 使用JWT请求访问令牌
179+
*/
180+
private String requestAccessToken(String jwt) throws Exception {
181+
// 构建请求头
182+
HttpHeaders headers = new HttpHeaders();
183+
headers.setContentType(MediaType.APPLICATION_JSON);
184+
headers.setBearerAuth(jwt);
185+
186+
// 构建请求体
187+
Map<String, Object> requestBody = new HashMap<>();
188+
requestBody.put("duration_seconds", DEFAULT_DURATION_SECONDS);
189+
requestBody.put("grant_type", "urn:ietf:params:oauth:grant-type:jwt-bearer");
190+
191+
// 发送请求
192+
HttpEntity<Map<String, Object>> request = new HttpEntity<>(requestBody, headers);
193+
ResponseEntity<String> response = restTemplate.postForEntity(TOKEN_URL, request, String.class);
194+
195+
if (response.getStatusCode() == HttpStatus.OK) {
196+
// 直接使用fastjson解析响应
197+
JSONObject jsonResponse = JSON.parseObject(response.getBody());
198+
String accessToken = jsonResponse.getString("access_token");
199+
200+
if (accessToken == null || accessToken.isEmpty()) {
201+
throw new RuntimeException("响应中未找到access_token字段");
202+
}
203+
204+
return accessToken;
205+
} else {
206+
throw new RuntimeException("Coze API返回错误,HTTP状态码: " + response.getStatusCode() +
207+
", 响应: " + response.getBody());
208+
}
209+
}
210+
211+
@Override
212+
public boolean isTokenValid() {
213+
if (tokenCache == null) {
214+
return false;
215+
}
216+
217+
// 检查token是否过期
218+
return !tokenCache.isExpired();
219+
}
220+
221+
@Override
222+
public void clearTokenCache() {
223+
tokenCache = null;
224+
}
225+
226+
/**
227+
* 使用虚拟线程异步刷新token
228+
*/
229+
private void refreshTokenAsync() {
230+
Thread.startVirtualThread(() -> {
231+
try {
232+
refreshToken();
233+
} catch (Exception e) {
234+
logger.error("虚拟线程异步刷新Coze Token失败,configId: {}: {}", configId, e.getMessage(), e);
235+
}
236+
});
237+
}
238+
239+
/**
240+
* 检查是否需要清除缓存(超过24小时未使用)
241+
*/
242+
public boolean needsCacheCleanup() {
243+
return tokenCache != null && tokenCache.needsCacheCleanup();
244+
}
245+
246+
}

src/main/java/com/xiaozhi/entity/SysAgent.java

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,6 @@ public class SysAgent extends SysConfig {
1717
/** 智能体名称 */
1818
private String agentName;
1919

20-
/** 平台智能体空间ID */
21-
private String spaceId;
22-
2320
/** 平台智能体ID */
2421
private String botId;
2522

@@ -51,15 +48,6 @@ public SysAgent setAgentName(String agentName) {
5148
return this;
5249
}
5350

54-
public String getSpaceId() {
55-
return spaceId;
56-
}
57-
58-
public SysAgent setSpaceId(String spaceId) {
59-
this.spaceId = spaceId;
60-
return this;
61-
}
62-
6351
public String getBotId() {
6452
return botId;
6553
}

0 commit comments

Comments
 (0)