From 29f69a18745197c62a2c569aa6bc87b3e578d82c Mon Sep 17 00:00:00 2001 From: aconeshana Date: Wed, 15 May 2024 18:38:20 +0800 Subject: [PATCH] feat(batch): add support to batch api --- .../com/theokanning/openai/batch/Batch.java | 143 ++++++++++++++++++ .../theokanning/openai/batch/BatchError.java | 31 ++++ .../openai/batch/BatchListRequest.java | 34 +++++ .../openai/batch/BatchListResult.java | 28 ++++ .../openai/batch/BatchRequestCounts.java | 26 ++++ .../openai/batch/CreateBatchRequest.java | 46 ++++++ .../com/theokanning/openai/batch/Status.java | 23 +++ .../theokanning/openai/client/OpenAiApi.java | 15 ++ .../openai/service/OpenAiService.java | 40 +++++ .../theokanning/openai/service/BatchTest.java | 91 +++++++++++ .../src/test/resources/batch-chat-data.jsonl | 4 + 11 files changed, 481 insertions(+) create mode 100644 api/src/main/java/com/theokanning/openai/batch/Batch.java create mode 100644 api/src/main/java/com/theokanning/openai/batch/BatchError.java create mode 100644 api/src/main/java/com/theokanning/openai/batch/BatchListRequest.java create mode 100644 api/src/main/java/com/theokanning/openai/batch/BatchListResult.java create mode 100644 api/src/main/java/com/theokanning/openai/batch/BatchRequestCounts.java create mode 100644 api/src/main/java/com/theokanning/openai/batch/CreateBatchRequest.java create mode 100644 api/src/main/java/com/theokanning/openai/batch/Status.java create mode 100644 service/src/test/java/com/theokanning/openai/service/BatchTest.java create mode 100644 service/src/test/resources/batch-chat-data.jsonl diff --git a/api/src/main/java/com/theokanning/openai/batch/Batch.java b/api/src/main/java/com/theokanning/openai/batch/Batch.java new file mode 100644 index 00000000..500c0b5d --- /dev/null +++ b/api/src/main/java/com/theokanning/openai/batch/Batch.java @@ -0,0 +1,143 @@ +package com.theokanning.openai.batch; + +import com.fasterxml.jackson.annotation.JsonProperty; +import lombok.AllArgsConstructor; +import lombok.Data; +import lombok.NoArgsConstructor; +import lombok.NonNull; + +import java.util.List; +import java.util.Map; + +/** + * @Author: acone.wu + * @date: 2024/5/15 13:51 + */ +@NoArgsConstructor +@AllArgsConstructor +@Data +public class Batch { + + @NonNull + private String id; + + /** + * The time frame within which the batch should be processed. + */ + @NonNull + @JsonProperty("completion_window") + private String completionWindow; + + /** + * The Unix timestamp (in seconds) for when the batch was created. + */ + @NonNull + @JsonProperty("created_at") + private Integer createdAt; + + /** + * The OpenAI API endpoint used by the batch. + */ + @NonNull + private String endpoint; + + /** + * The ID of the input file for the batch. + */ + @NonNull + @JsonProperty("input_file_id") + private String inputFileId; + + /** + * The object type, which is always `batch`. + */ + private String object = "batch"; + + /** + * The current status of the batch. + * */ + private Status status; + + /** + * The Unix timestamp (in seconds) for when the batch was cancelled. + * */ + @JsonProperty("cancelled_at") + private Integer cancelledAt; + + /** + * The Unix timestamp (in seconds) for when the batch started cancelling. + * */ + @JsonProperty("cancelling_at") + private Integer cancellingAt; + + /** + * The Unix timestamp (in seconds) for when the batch was completed. + * */ + @JsonProperty("completed_at") + private Integer completedAt; + + /** + * The ID of the file containing the outputs of requests with errors. + * */ + @JsonProperty("error_file_id") + private String errorFileId; + + /** + * Errors associated with batch processing. + * */ + private Errors errors; + + /** + * The Unix timestamp (in seconds) for when the batch expired. + * */ + @JsonProperty("expired_at") + private Integer expiredAt; + + /** + * The Unix timestamp (in seconds) for when the batch will expire. + * */ + @JsonProperty("expires_at") + private Integer expiresAt; + + /** + * The Unix timestamp (in seconds) for when the batch failed. + * */ + @JsonProperty("failed_at") + private Integer failedAt; + + /** + * The Unix timestamp (in seconds) for when the batch started finalizing. + * */ + @JsonProperty("finalizing_at") + private Integer finalizingAt; + + /** + * The Unix timestamp (in seconds) for when the batch started processing. + * */ + @JsonProperty("in_progress_at") + private Integer inProgressAt; + + /** + * Metadata for storing additional information about the object. + * */ + private Map metadata; + + /** + * The ID of the file containing the outputs of successfully executed requests. + * */ + @JsonProperty("output_file_id") + private String outputFileId; + + /** + * The request counts for different statuses within the batch. + * */ + @JsonProperty("request_counts") + private BatchRequestCounts requestCounts; + + @Data + public static class Errors { + private List data; + + private String object; + } +} diff --git a/api/src/main/java/com/theokanning/openai/batch/BatchError.java b/api/src/main/java/com/theokanning/openai/batch/BatchError.java new file mode 100644 index 00000000..d79371d3 --- /dev/null +++ b/api/src/main/java/com/theokanning/openai/batch/BatchError.java @@ -0,0 +1,31 @@ +package com.theokanning.openai.batch; + +import lombok.Data; + +/** + * @Author: acone.wu + * @date: 2024/5/15 13:50 + */ +@Data +public class BatchError { + + /** + * An error code identifying the error type. + */ + String code; + + /** + * The line number of the input file where the error occurred, if applicable + */ + Integer line; + + /** + * A human-readable message providing more details about the error. + */ + String message; + + /** + * The name of the parameter that caused the error, if applicable. + */ + String param; +} diff --git a/api/src/main/java/com/theokanning/openai/batch/BatchListRequest.java b/api/src/main/java/com/theokanning/openai/batch/BatchListRequest.java new file mode 100644 index 00000000..3e091203 --- /dev/null +++ b/api/src/main/java/com/theokanning/openai/batch/BatchListRequest.java @@ -0,0 +1,34 @@ +package com.theokanning.openai.batch; + +import lombok.AllArgsConstructor; +import lombok.Builder; +import lombok.Data; +import lombok.NoArgsConstructor; + +/** + * @Author: acone.wu + * @date: 2024/5/15 16:12 + */ +@NoArgsConstructor +@AllArgsConstructor +@Builder +@Data +public class BatchListRequest { + + /** + * A cursor for use in pagination. + *

+ * `after` is an object ID that defines your place in the list. For instance, if + * you make a list request and receive 100 objects, ending with obj_foo, your + * subsequent call can include after=obj_foo in order to fetch the next page of the + * list. + */ + private String after; + + /** + * A limit on the number of objects to be returned. + *

+ * Limit can range between 1 and 100, and the default is 20. + */ + private Integer limit; +} diff --git a/api/src/main/java/com/theokanning/openai/batch/BatchListResult.java b/api/src/main/java/com/theokanning/openai/batch/BatchListResult.java new file mode 100644 index 00000000..308ca51b --- /dev/null +++ b/api/src/main/java/com/theokanning/openai/batch/BatchListResult.java @@ -0,0 +1,28 @@ +package com.theokanning.openai.batch; + +import com.fasterxml.jackson.annotation.JsonProperty; +import lombok.Data; + +import java.util.List; + +/** + * @Author: acone.wu + * @date: 2024/5/15 16:07 + */ +@Data +public class BatchListResult { + + private String object; + + private List data; + + @JsonProperty("first_id") + private String firstId; + + @JsonProperty("last_id") + private String lastId; + + @JsonProperty("has_more") + private Boolean hasMore; + +} diff --git a/api/src/main/java/com/theokanning/openai/batch/BatchRequestCounts.java b/api/src/main/java/com/theokanning/openai/batch/BatchRequestCounts.java new file mode 100644 index 00000000..311e5031 --- /dev/null +++ b/api/src/main/java/com/theokanning/openai/batch/BatchRequestCounts.java @@ -0,0 +1,26 @@ +package com.theokanning.openai.batch; + +import lombok.Data; + +/** + * @Author: acone.wu + * @date: 2024/5/15 15:36 + */ +@Data +public class BatchRequestCounts { + + /** + * Number of requests that have been completed successfully. + * */ + private Integer completed; + + /** + * Number of requests that have failed. + * */ + private Integer failed; + + /** + * Total number of requests in the batch. + * */ + private Integer total; +} diff --git a/api/src/main/java/com/theokanning/openai/batch/CreateBatchRequest.java b/api/src/main/java/com/theokanning/openai/batch/CreateBatchRequest.java new file mode 100644 index 00000000..c6dabd50 --- /dev/null +++ b/api/src/main/java/com/theokanning/openai/batch/CreateBatchRequest.java @@ -0,0 +1,46 @@ +package com.theokanning.openai.batch; + +import com.fasterxml.jackson.annotation.JsonProperty; +import lombok.*; + +import java.util.Map; + +/** + * @Author: acone.wu + * @date: 2024/5/14 21:11 + */ +@NoArgsConstructor +@AllArgsConstructor +@Builder +@Data +public class CreateBatchRequest { + + /** + * The ID of an uploaded file that contains requests for the new batch. + *

+ * See upload file for how to upload a file. + *

+ * Your input file must be formatted as a JSONL file, and must be uploaded with the purpose batch. The file can contain up to 50,000 requests, and can be up to 100 MB in size. + */ + @NonNull + @JsonProperty("input_file_id") + String inputFileId; + + /** + * The endpoint to be used for all requests in the batch. Currently /v1/chat/completions, /v1/embeddings, and /v1/completions are supported. Note that /v1/embeddings batches are also restricted to a maximum of 50,000 embedding inputs across all requests in the batch. + */ + @NonNull + String endpoint; + + /** + * The time frame within which the batch should be processed. Currently only 24h is supported. + */ + @NonNull + @JsonProperty("completion_window") + String compWindow = "24h"; + + /** + * Optional custom metadata for the batch. + */ + Map metadata; +} diff --git a/api/src/main/java/com/theokanning/openai/batch/Status.java b/api/src/main/java/com/theokanning/openai/batch/Status.java new file mode 100644 index 00000000..ff62b11d --- /dev/null +++ b/api/src/main/java/com/theokanning/openai/batch/Status.java @@ -0,0 +1,23 @@ +package com.theokanning.openai.batch; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonValue; + +/** + * @Author: acone.wu + * @date: 2024/5/15 15:40 + */ +public enum Status { + VALIDATING, FAILED, IN_PROGRESS, FINALIZING, COMPLETED, EXPIRED, CANCELLING, CANCELLED + ; + + @JsonCreator + public static Status fromValue(String value) { + return Status.valueOf(value.toUpperCase()); + } + + @JsonValue + public String toValue() { + return this.name().toLowerCase(); + } +} diff --git a/client/src/main/java/com/theokanning/openai/client/OpenAiApi.java b/client/src/main/java/com/theokanning/openai/client/OpenAiApi.java index 7342e953..1f6b535e 100644 --- a/client/src/main/java/com/theokanning/openai/client/OpenAiApi.java +++ b/client/src/main/java/com/theokanning/openai/client/OpenAiApi.java @@ -7,6 +7,9 @@ import com.theokanning.openai.audio.CreateSpeechRequest; import com.theokanning.openai.audio.TranscriptionResult; import com.theokanning.openai.audio.TranslationResult; +import com.theokanning.openai.batch.Batch; +import com.theokanning.openai.batch.BatchListResult; +import com.theokanning.openai.batch.CreateBatchRequest; import com.theokanning.openai.billing.BillingUsage; import com.theokanning.openai.billing.Subscription; import com.theokanning.openai.completion.CompletionRequest; @@ -321,4 +324,16 @@ public interface OpenAiApi { @Headers("OpenAI-Beta: assistants=v1") @GET("/v1/threads/{thread_id}/runs/{run_id}/steps") Single> listRunSteps(@Path("thread_id") String threadId, @Path("run_id") String runId, @QueryMap Map listSearchParameters); + + @POST("/v1/batches") + Single createBatch(@Body CreateBatchRequest createBatchRequest); + + @GET("/v1/batches/{batch_id}") + Single retrieveBatch(@Path("batch_id") String batchId); + + @GET("/v1/batches") + Single listBatches(@QueryMap Map listSearchParameters); + + @POST("/v1/batches/{batch_id}/cancel") + Single cancelBatch(@Path("batch_id") String batchId); } diff --git a/service/src/main/java/com/theokanning/openai/service/OpenAiService.java b/service/src/main/java/com/theokanning/openai/service/OpenAiService.java index 52ab6b0f..d3752bcc 100644 --- a/service/src/main/java/com/theokanning/openai/service/OpenAiService.java +++ b/service/src/main/java/com/theokanning/openai/service/OpenAiService.java @@ -9,6 +9,10 @@ import com.theokanning.openai.*; import com.theokanning.openai.assistants.*; import com.theokanning.openai.audio.*; +import com.theokanning.openai.batch.Batch; +import com.theokanning.openai.batch.BatchListRequest; +import com.theokanning.openai.batch.BatchListResult; +import com.theokanning.openai.batch.CreateBatchRequest; import com.theokanning.openai.billing.BillingUsage; import com.theokanning.openai.billing.Subscription; import com.theokanning.openai.client.OpenAiApi; @@ -659,4 +663,40 @@ public BillingUsage billingUsage(@NotNull LocalDate starDate, @NotNull LocalDate return billingUsage.blockingGet(); } + /** + * Creates and executes a batch from an uploaded file of requests + */ + public Batch createBatch(@NotNull CreateBatchRequest request) { + Single batch = api.createBatch(request); + return batch.blockingGet(); + } + + /** + * Retrieves a batch. + */ + public Batch retrieveBatch(@NotNull String batchId) { + Single batch = api.retrieveBatch(batchId); + return batch.blockingGet(); + } + + /** + * List your organization's batches. + */ + public BatchListResult listBatch(BatchListRequest batchListParameters) { + Map search = new HashMap<>(); + if (batchListParameters != null) { + ObjectMapper mapper = defaultObjectMapper(); + search = mapper.convertValue(batchListParameters, Map.class); + } + Single batches = api.listBatches(search); + return batches.blockingGet(); + } + + /** + * Cancels an in-progress batch. + */ + public Batch cancel(@NotNull String batchId) { + Single batch = api.cancelBatch(batchId); + return batch.blockingGet(); + } } diff --git a/service/src/test/java/com/theokanning/openai/service/BatchTest.java b/service/src/test/java/com/theokanning/openai/service/BatchTest.java new file mode 100644 index 00000000..db7bb6f1 --- /dev/null +++ b/service/src/test/java/com/theokanning/openai/service/BatchTest.java @@ -0,0 +1,91 @@ +package com.theokanning.openai.service; + +import com.theokanning.openai.DeleteResult; +import com.theokanning.openai.batch.*; +import com.theokanning.openai.file.File; +import org.junit.jupiter.api.MethodOrderer; +import org.junit.jupiter.api.Order; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.TestMethodOrder; + +import java.util.concurrent.TimeUnit; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; + +/** + * @Author: acone.wu + * @date: 2024/5/15 15:50 + */ +@TestMethodOrder(MethodOrderer.OrderAnnotation.class) +public class BatchTest { + + static String filePath = "src/test/resources/batch-chat-data.jsonl"; + + String token = System.getenv("OPENAI_TOKEN"); + OpenAiService service = new OpenAiService(token); + static String fileId; + + static Batch batch; + + @Test + @Order(1) + void uploadFile() throws Exception { + File file = service.uploadFile("batch", filePath); + fileId = file.getId(); + + assertEquals("batch", file.getPurpose()); + assertEquals(filePath, file.getFilename()); + + // wait for file to be processed + TimeUnit.SECONDS.sleep(10); + } + + @Test + @Order(2) + void createBatch() { + CreateBatchRequest request = CreateBatchRequest.builder() + .inputFileId(fileId) + .endpoint("/v1/chat/completions") + .compWindow("24h") + .build(); + batch = service.createBatch(request); + + assertEquals("/v1/chat/completions", batch.getEndpoint()); + assertEquals("24h", batch.getCompletionWindow()); + assertEquals(Status.VALIDATING, batch.getStatus()); + } + + @Test + @Order(3) + void retrieveBatch() { + Batch detail = service.retrieveBatch(batch.getId()); + + assertEquals("/v1/chat/completions", detail.getEndpoint()); + assertEquals("24h", detail.getCompletionWindow()); + } + + @Test + @Order(4) + void listBatches() { + BatchListRequest request = BatchListRequest.builder().build(); + BatchListResult batchListResult = service.listBatch(request); + + assertTrue(batchListResult.getData().stream().anyMatch(b -> batch.getId().equals(b.getId()))); + } + + @Test + @Order(4) + void cancelBatch() { + Batch cancelled = service.cancel(batch.getId()); + + assertEquals(Status.CANCELLING, cancelled.getStatus()); + } + + @Test + @Order(5) + void deleteFile() { + DeleteResult deleteResult = service.deleteFile(fileId); + assertTrue(deleteResult.isDeleted()); + } +} diff --git a/service/src/test/resources/batch-chat-data.jsonl b/service/src/test/resources/batch-chat-data.jsonl new file mode 100644 index 00000000..e1439918 --- /dev/null +++ b/service/src/test/resources/batch-chat-data.jsonl @@ -0,0 +1,4 @@ +{"custom_id": "request-1", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "gpt-3.5-turbo", "messages": [{"role": "system", "content": "Marv is a factual chatbot that is also sarcastic."}, {"role": "user", "content": "What's the capital of France?"}]}} +{"custom_id": "request-2", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "gpt-3.5-turbo", "messages": [{"role": "system", "content": "Marv is a factual chatbot that is also sarcastic."}, {"role": "user", "content": "Who wrote 'Romeo and Juliet'?"}]}} +{"custom_id": "request-3", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "gpt-3.5-turbo", "messages": [{"role": "system", "content": "Marv is a factual chatbot that is also sarcastic."}, {"role": "user", "content": "How far is the Moon from Earth?"}]}} +{"custom_id": "request-4", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "gpt-3.5-turbo", "messages": [{"role": "system", "content": "Marv is a factual chatbot that is also sarcastic."}, {"role": "user", "content": "What's the capital of France?"}]}}