Skip to content

Commit 2d5be6f

Browse files
authored
Merge branch 'main' into 388-chat-completition-response-format
2 parents bd3c204 + 3d60d6e commit 2d5be6f

File tree

7 files changed

+195
-8
lines changed

7 files changed

+195
-8
lines changed
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
package com.theokanning.openai.assistants;
2+
3+
import lombok.AllArgsConstructor;
4+
import lombok.Builder;
5+
import lombok.Data;
6+
import lombok.NoArgsConstructor;
7+
8+
import java.util.Map;
9+
10+
/**
11+
* @description:
12+
* @author: vacuity
13+
* @create: 2023-11-20 10:09
14+
**/
15+
16+
17+
@Builder
18+
@NoArgsConstructor
19+
@AllArgsConstructor
20+
@Data
21+
public class AssistantFunction {
22+
23+
private String description;
24+
25+
private String name;
26+
27+
private Map<String, Object> parameters;
28+
}

api/src/main/java/com/theokanning/openai/assistants/Tool.java

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
package com.theokanning.openai.assistants;
22

3-
import com.theokanning.openai.completion.chat.ChatFunction;
43
import lombok.AllArgsConstructor;
54
import lombok.Data;
65
import lombok.NoArgsConstructor;
@@ -17,5 +16,5 @@ public class Tool {
1716
/**
1817
* Function definition, only used if type is "function"
1918
*/
20-
ChatFunction function;
19+
AssistantFunction function;
2120
}

api/src/main/java/com/theokanning/openai/runs/SubmitToolOutputsRequest.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,5 +22,5 @@
2222
public class SubmitToolOutputsRequest {
2323

2424
@JsonProperty("tool_outputs")
25-
private List<SubmitToolOutputRequestItem> tool_outputs;
25+
private List<SubmitToolOutputRequestItem> toolOutputs;
2626
}

client/src/main/java/com/theokanning/openai/client/OpenAiApi.java

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -298,7 +298,8 @@ public interface OpenAiApi {
298298

299299
@Headers("OpenAI-Beta: assistants=v1")
300300
@GET("/v1/threads/{thread_id}/runs")
301-
Single<OpenAiResponse<Run>> listRuns(@Path("thread_id") String threadId, @Body ListSearchParameters listSearchParameters);
301+
Single<OpenAiResponse<Run>> listRuns(@Path("thread_id") String threadId, @QueryMap Map<String, String> listSearchParameters);
302+
302303

303304
@Headers("OpenAI-Beta: assistants=v1")
304305
@POST("/v1/threads/{thread_id}/runs/{run_id}/submit_tool_outputs")
@@ -319,5 +320,5 @@ public interface OpenAiApi {
319320

320321
@Headers("OpenAI-Beta: assistants=v1")
321322
@GET("/v1/threads/{thread_id}/runs/{run_id}/steps")
322-
Single<OpenAiResponse<RunStep>> listRunSteps(@Path("thread_id") String threadId, @Path("run_id") String runId, @Body ListSearchParameters listSearchParameters);
323+
Single<OpenAiResponse<RunStep>> listRunSteps(@Path("thread_id") String threadId, @Path("run_id") String runId, @QueryMap Map<String, String> listSearchParameters);
323324
}

gradle.properties

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
GROUP=com.theokanning.openai-gpt3-java
2-
VERSION_NAME=0.18.1
2+
VERSION_NAME=0.18.2
33

44
POM_URL=https://github.com/theokanning/openai-java
55
POM_SCM_URL=https://github.com/theokanning/openai-java

service/src/main/java/com/theokanning/openai/service/OpenAiService.java

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@
5959
import java.io.IOException;
6060
import java.time.Duration;
6161
import java.time.LocalDate;
62+
import java.util.HashMap;
6263
import java.util.List;
6364
import java.util.Map;
6465
import java.util.Objects;
@@ -471,7 +472,12 @@ public Run modifyRun(String threadId, String runId, Map<String, String> metadata
471472
}
472473

473474
public OpenAiResponse<Run> listRuns(String threadId, ListSearchParameters listSearchParameters) {
474-
return execute(api.listRuns(threadId, listSearchParameters));
475+
Map<String, String> search = new HashMap<>();
476+
if (listSearchParameters != null) {
477+
ObjectMapper mapper = defaultObjectMapper();
478+
search = mapper.convertValue(listSearchParameters, Map.class);
479+
}
480+
return execute(api.listRuns(threadId, search));
475481
}
476482

477483
public Run submitToolOutputs(String threadId, String runId, SubmitToolOutputsRequest submitToolOutputsRequest) {
@@ -491,7 +497,12 @@ public RunStep retrieveRunStep(String threadId, String runId, String stepId) {
491497
}
492498

493499
public OpenAiResponse<RunStep> listRunSteps(String threadId, String runId, ListSearchParameters listSearchParameters) {
494-
return execute(api.listRunSteps(threadId, runId, listSearchParameters));
500+
Map<String, String> search = new HashMap<>();
501+
if (listSearchParameters != null) {
502+
ObjectMapper mapper = defaultObjectMapper();
503+
search = mapper.convertValue(listSearchParameters, Map.class);
504+
}
505+
return execute(api.listRunSteps(threadId, runId, search));
495506
}
496507

497508
/**
Lines changed: 148 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,148 @@
1+
package com.theokanning.openai.service;
2+
3+
import com.fasterxml.jackson.annotation.JsonInclude;
4+
import com.fasterxml.jackson.core.JsonProcessingException;
5+
import com.fasterxml.jackson.core.type.TypeReference;
6+
import com.fasterxml.jackson.databind.DeserializationFeature;
7+
import com.fasterxml.jackson.databind.ObjectMapper;
8+
import com.fasterxml.jackson.databind.PropertyNamingStrategy;
9+
import com.theokanning.openai.ListSearchParameters;
10+
import com.theokanning.openai.OpenAiResponse;
11+
import com.theokanning.openai.assistants.Assistant;
12+
import com.theokanning.openai.assistants.AssistantFunction;
13+
import com.theokanning.openai.assistants.AssistantRequest;
14+
import com.theokanning.openai.assistants.AssistantToolsEnum;
15+
import com.theokanning.openai.assistants.Tool;
16+
import com.theokanning.openai.completion.chat.ChatCompletionRequest;
17+
import com.theokanning.openai.completion.chat.ChatFunction;
18+
import com.theokanning.openai.completion.chat.ChatFunctionCall;
19+
import com.theokanning.openai.messages.Message;
20+
import com.theokanning.openai.messages.MessageRequest;
21+
import com.theokanning.openai.runs.RequiredAction;
22+
import com.theokanning.openai.runs.Run;
23+
import com.theokanning.openai.runs.RunCreateRequest;
24+
import com.theokanning.openai.runs.RunStep;
25+
import com.theokanning.openai.runs.SubmitToolOutputRequestItem;
26+
import com.theokanning.openai.runs.SubmitToolOutputs;
27+
import com.theokanning.openai.runs.SubmitToolOutputsRequest;
28+
import com.theokanning.openai.runs.ToolCall;
29+
import com.theokanning.openai.threads.Thread;
30+
import com.theokanning.openai.threads.ThreadRequest;
31+
import com.theokanning.openai.utils.TikTokensUtil;
32+
import org.junit.jupiter.api.Test;
33+
34+
import java.time.Duration;
35+
import java.util.ArrayList;
36+
import java.util.List;
37+
import java.util.Map;
38+
import java.util.Objects;
39+
40+
import static org.junit.jupiter.api.Assertions.assertEquals;
41+
import static org.junit.jupiter.api.Assertions.assertNotNull;
42+
43+
class AssistantFunctionTest {
44+
String token = System.getenv("OPENAI_TOKEN");
45+
OpenAiService service = new OpenAiService(token, Duration.ofMinutes(1));
46+
47+
@Test
48+
void createRetrieveRun() throws JsonProcessingException {
49+
50+
ObjectMapper mapper = new ObjectMapper();
51+
mapper.configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false);
52+
mapper.setSerializationInclusion(JsonInclude.Include.NON_NULL);
53+
mapper.setPropertyNamingStrategy(PropertyNamingStrategy.SNAKE_CASE);
54+
mapper.addMixIn(ChatFunction.class, ChatFunctionMixIn.class);
55+
mapper.addMixIn(ChatCompletionRequest.class, ChatCompletionRequestMixIn.class);
56+
mapper.addMixIn(ChatFunctionCall.class, ChatFunctionCallMixIn.class);
57+
58+
String funcDef = "{\n" +
59+
" \"type\": \"object\",\n" +
60+
" \"properties\": {\n" +
61+
" \"location\": {\n" +
62+
" \"type\": \"string\",\n" +
63+
" \"description\": \"The city and state, e.g. San Francisco, CA\"\n" +
64+
" },\n" +
65+
" \"unit\": {\n" +
66+
" \"type\": \"string\",\n" +
67+
" \"enum\": [\"celsius\", \"fahrenheit\"]\n" +
68+
" }\n" +
69+
" },\n" +
70+
" \"required\": [\"location\"]\n" +
71+
"}";
72+
Map<String, Object> funcParameters = mapper.readValue(funcDef, new TypeReference<Map<String, Object>>() {});
73+
AssistantFunction function = AssistantFunction.builder()
74+
.name("weather_reporter")
75+
.description("Get the current weather of a location")
76+
.parameters(funcParameters)
77+
.build();
78+
79+
List<Tool> toolList = new ArrayList<>();
80+
Tool funcTool = new Tool(AssistantToolsEnum.FUNCTION, function);
81+
toolList.add(funcTool);
82+
83+
84+
AssistantRequest assistantRequest = AssistantRequest.builder()
85+
.model(TikTokensUtil.ModelEnum.GPT_4_1106_preview.getName())
86+
.name("MATH_TUTOR")
87+
.instructions("You are a personal Math Tutor.")
88+
.tools(toolList)
89+
.build();
90+
Assistant assistant = service.createAssistant(assistantRequest);
91+
92+
ThreadRequest threadRequest = ThreadRequest.builder()
93+
.build();
94+
Thread thread = service.createThread(threadRequest);
95+
96+
MessageRequest messageRequest = MessageRequest.builder()
97+
.content("What's the weather of Xiamen?")
98+
.build();
99+
100+
Message message = service.createMessage(thread.getId(), messageRequest);
101+
102+
RunCreateRequest runCreateRequest = RunCreateRequest.builder()
103+
.assistantId(assistant.getId())
104+
.build();
105+
106+
Run run = service.createRun(thread.getId(), runCreateRequest);
107+
assertNotNull(run);
108+
109+
Run retrievedRun = service.retrieveRun(thread.getId(), run.getId());
110+
while (!(retrievedRun.getStatus().equals("completed"))
111+
&& !(retrievedRun.getStatus().equals("failed"))
112+
&& !(retrievedRun.getStatus().equals("requires_action"))){
113+
retrievedRun = service.retrieveRun(thread.getId(), run.getId());
114+
}
115+
if (retrievedRun.getStatus().equals("requires_action")) {
116+
RequiredAction requiredAction = retrievedRun.getRequiredAction();
117+
System.out.println("requiredAction");
118+
System.out.println(mapper.writeValueAsString(requiredAction));
119+
List<ToolCall> toolCalls = requiredAction.getSubmitToolOutputs().getToolCalls();
120+
ToolCall toolCall = toolCalls.get(0);
121+
String toolCallId = toolCall.getId();
122+
123+
SubmitToolOutputRequestItem toolOutputRequestItem = SubmitToolOutputRequestItem.builder()
124+
.toolCallId(toolCallId)
125+
.output("sunny")
126+
.build();
127+
List<SubmitToolOutputRequestItem> toolOutputRequestItems = new ArrayList<>();
128+
toolOutputRequestItems.add(toolOutputRequestItem);
129+
SubmitToolOutputsRequest submitToolOutputsRequest = SubmitToolOutputsRequest.builder()
130+
.toolOutputs(toolOutputRequestItems)
131+
.build();
132+
retrievedRun = service.submitToolOutputs(retrievedRun.getThreadId(), retrievedRun.getId(), submitToolOutputsRequest);
133+
134+
while (!(retrievedRun.getStatus().equals("completed"))
135+
&& !(retrievedRun.getStatus().equals("failed"))
136+
&& !(retrievedRun.getStatus().equals("requires_action"))){
137+
retrievedRun = service.retrieveRun(thread.getId(), run.getId());
138+
}
139+
140+
OpenAiResponse<Message> response = service.listMessages(thread.getId());
141+
142+
List<Message> messages = response.getData();
143+
144+
System.out.println(mapper.writeValueAsString(messages));
145+
146+
}
147+
}
148+
}

0 commit comments

Comments
 (0)