Skip to content

Commit c62433a

Browse files
committed
add logging and test cases
1 parent acc69b8 commit c62433a

File tree

12 files changed

+334
-25
lines changed

12 files changed

+334
-25
lines changed

.gitignore

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,4 +29,5 @@ log/
2929
target/
3030

3131
# ChatGPT-Java-API Specific Ignore
32-
.env
32+
.env
33+
debug.log

build.gradle.kts

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,10 +25,14 @@ dependencies {
2525
implementation("com.fasterxml.jackson.core:jackson-annotations:2.15.3")
2626
implementation("com.fasterxml.jackson.module:jackson-module-kotlin:2.15.3")
2727

28+
implementation("org.slf4j:slf4j-api:2.0.9")
29+
2830
implementation("org.jetbrains:annotations:24.0.1")
2931

3032
testImplementation("io.github.cdimascio:dotenv-kotlin:6.4.1")
3133
testImplementation("org.junit.jupiter:junit-jupiter:5.9.2")
34+
testImplementation("com.squareup.okhttp3:okhttp:4.9.2")
35+
testImplementation("com.squareup.okhttp3:mockwebserver:4.9.2")
3236
}
3337

3438
kotlin {

examples/build.gradle.kts

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@ dependencies {
1414
implementation("com.fasterxml.jackson.core:jackson-databind:2.15.3")
1515
implementation("io.github.cdimascio:dotenv-kotlin:6.4.1")
1616

17+
implementation("ch.qos.logback:logback-classic:1.4.11")
18+
1719
// https://mvnrepository.com/artifact/org.mariuszgromada.math/MathParser.org-mXparser
1820
// Used for tool tests
1921
implementation("org.mariuszgromada.math:MathParser.org-mXparser:5.2.1")
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
<configuration>
2+
<appender name="FILE" class="ch.qos.logback.core.FileAppender">
3+
<file>debug.log</file>
4+
<append>false</append>
5+
<encoder>
6+
<pattern>%date %level [%thread] %logger{10} %msg%n</pattern>
7+
</encoder>
8+
</appender>
9+
10+
<root level="DEBUG">
11+
<appender-ref ref="FILE"/>
12+
</root>
13+
</configuration>

src/main/kotlin/com/cjcrafter/openai/AzureOpenAI.kt

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,24 +13,23 @@ import org.jetbrains.annotations.ApiStatus
1313
*
1414
* This class constructs url in the form of: https://<azureBaseUrl>/openai/deployments/<modelName>/<endpoint>?api-version=<apiVersion>
1515
*
16-
* @property azureBaseUrl The base URL for the Azure OpenAI API. Usually https://<your_resource_group>.openai.azure.com
1716
* @property apiVersion The API version to use. Defaults to 2023-03-15-preview.
1817
* @property modelName The model name to use. This is the name of the model deployed to Azure.
1918
*/
2019
class AzureOpenAI @ApiStatus.Internal constructor(
2120
apiKey: String,
2221
organization: String? = null,
2322
client: OkHttpClient = OkHttpClient(),
24-
private val azureBaseUrl: String = "",
23+
baseUrl: String = "https://api.openai.com",
2524
private val apiVersion: String = "2023-03-15-preview",
2625
private val modelName: String = ""
27-
) : OpenAIImpl(apiKey, organization, client) {
26+
) : OpenAIImpl(apiKey, organization, client, baseUrl) {
2827

2928
override fun buildRequest(request: Any, endpoint: String): Request {
3029
val json = objectMapper.writeValueAsString(request)
3130
val body: RequestBody = json.toRequestBody(mediaType)
3231
return Request.Builder()
33-
.url("$azureBaseUrl/openai/deployments/$modelName/$endpoint?api-version=$apiVersion")
32+
.url("$baseUrl/openai/deployments/$modelName/$endpoint?api-version=$apiVersion")
3433
.addHeader("Content-Type", "application/json")
3534
.addHeader("api-key", apiKey)
3635
.apply { if (organization != null) addHeader("OpenAI-Organization", organization) }

src/main/kotlin/com/cjcrafter/openai/OpenAI.kt

Lines changed: 25 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import com.cjcrafter.openai.completions.CompletionRequest
66
import com.cjcrafter.openai.completions.CompletionResponse
77
import com.cjcrafter.openai.completions.CompletionResponseChunk
88
import com.cjcrafter.openai.util.OpenAIDslMarker
9+
import com.fasterxml.jackson.annotation.JsonAutoDetect
910
import com.fasterxml.jackson.annotation.JsonInclude
1011
import com.fasterxml.jackson.databind.DeserializationFeature
1112
import com.fasterxml.jackson.databind.ObjectMapper
@@ -14,6 +15,7 @@ import com.fasterxml.jackson.module.kotlin.jacksonObjectMapper
1415
import okhttp3.OkHttpClient
1516
import org.jetbrains.annotations.ApiStatus
1617
import org.jetbrains.annotations.Contract
18+
import org.slf4j.LoggerFactory
1719

1820
interface OpenAI {
1921

@@ -91,46 +93,49 @@ interface OpenAI {
9193
protected var apiKey: String? = null
9294
protected var organization: String? = null
9395
protected var client: OkHttpClient = OkHttpClient()
96+
protected var baseUrl: String = "https://api.openai.com"
9497

9598
fun apiKey(apiKey: String) = apply { this.apiKey = apiKey }
9699
fun organization(organization: String?) = apply { this.organization = organization }
97100
fun client(client: OkHttpClient) = apply { this.client = client }
101+
fun baseUrl(baseUrl: String) = apply { this.baseUrl = baseUrl }
98102

99103
@Contract(pure = true)
100104
open fun build(): OpenAI {
101105
return OpenAIImpl(
102-
apiKey ?: throw IllegalStateException("apiKey must be defined to use OpenAI"),
103-
organization,
104-
client
106+
apiKey = apiKey ?: throw IllegalStateException("apiKey must be defined to use OpenAI"),
107+
organization = organization,
108+
client = client,
109+
baseUrl = baseUrl,
105110
)
106111
}
107112
}
108113

109114
@OpenAIDslMarker
110115
class AzureBuilder internal constructor(): Builder() {
111-
private var azureBaseUrl: String? = null
112116
private var apiVersion: String? = null
113117
private var modelName: String? = null
114118

115-
fun azureBaseUrl(azureBaseUrl: String) = apply { this.azureBaseUrl = azureBaseUrl }
116119
fun apiVersion(apiVersion: String) = apply { this.apiVersion = apiVersion }
117120
fun modelName(modelName: String) = apply { this.modelName = modelName }
118121

119122
@Contract(pure = true)
120123
override fun build(): OpenAI {
121124
return AzureOpenAI(
122-
apiKey ?: throw IllegalStateException("apiKey must be defined to use OpenAI"),
123-
organization,
124-
client,
125-
azureBaseUrl ?: throw IllegalStateException("azureBaseUrl must be defined for azure"),
126-
apiVersion ?: throw IllegalStateException("apiVersion must be defined for azure"),
127-
modelName ?: throw IllegalStateException("modelName must be defined for azure")
125+
apiKey = apiKey ?: throw IllegalStateException("apiKey must be defined to use OpenAI"),
126+
organization = organization,
127+
client = client,
128+
baseUrl = if (baseUrl == "https://api.openai.com") throw IllegalStateException("baseUrl must be set to an azure endpoint") else baseUrl,
129+
apiVersion = apiVersion ?: throw IllegalStateException("apiVersion must be defined for azure"),
130+
modelName = modelName ?: throw IllegalStateException("modelName must be defined for azure")
128131
)
129132
}
130133
}
131134

132135
companion object {
133136

137+
internal val logger = LoggerFactory.getLogger(OpenAI::class.java)
138+
134139
/**
135140
* Instantiates a builder for a default OpenAI instance. For Azure's
136141
* OpenAI, use [azureBuilder] instead.
@@ -155,6 +160,14 @@ interface OpenAI {
155160
setSerializationInclusion(JsonInclude.Include.NON_NULL)
156161
configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false)
157162

163+
// By default, Jackson can serialize fields AND getters. We just want fields.
164+
setVisibility(serializationConfig.getDefaultVisibilityChecker()
165+
.withFieldVisibility(JsonAutoDetect.Visibility.ANY)
166+
.withGetterVisibility(JsonAutoDetect.Visibility.NONE)
167+
.withSetterVisibility(JsonAutoDetect.Visibility.NONE)
168+
.withCreatorVisibility(JsonAutoDetect.Visibility.NONE)
169+
)
170+
158171
// Register modules with custom serializers/deserializers
159172
val module = SimpleModule().apply {
160173
addSerializer(ToolChoice::class.java, ToolChoice.serializer())
@@ -180,10 +193,4 @@ interface OpenAI {
180193
consumer(chunk)
181194
}
182195
}
183-
}
184-
185-
@Contract(pure = true)
186-
fun openAI(init: OpenAI.Builder.() -> Unit) = OpenAI.builder().apply(init).build()
187-
188-
@Contract(pure = true)
189-
fun azureOpenAI(init: OpenAI.AzureBuilder.() -> Unit) = OpenAI.azureBuilder().apply(init).build()
196+
}
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
package com.cjcrafter.openai
2+
3+
import org.jetbrains.annotations.Contract
4+
5+
/**
6+
* Builds an [OpenAI] instance using the default implementation.
7+
*/
8+
@Contract(pure = true)
9+
fun openAI(init: OpenAI.Builder.() -> Unit) = OpenAI.builder().apply(init).build()
10+
11+
/**
12+
* Builds an [OpenAI] instance using the Azure implementation.
13+
*/
14+
@Contract(pure = true)
15+
fun azureOpenAI(init: OpenAI.AzureBuilder.() -> Unit) = OpenAI.azureBuilder().apply(init).build()

src/main/kotlin/com/cjcrafter/openai/OpenAIImpl.kt

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,8 @@ import java.io.IOException
1616
open class OpenAIImpl @ApiStatus.Internal constructor(
1717
protected val apiKey: String,
1818
protected val organization: String? = null,
19-
private val client: OkHttpClient = OkHttpClient()
19+
protected val client: OkHttpClient = OkHttpClient(),
20+
protected val baseUrl: String = "https://api.openai.com",
2021
): OpenAI {
2122
protected val mediaType = "application/json; charset=utf-8".toMediaType()
2223
protected val objectMapper = OpenAI.createObjectMapper()
@@ -25,7 +26,7 @@ open class OpenAIImpl @ApiStatus.Internal constructor(
2526
val json = objectMapper.writeValueAsString(request)
2627
val body: RequestBody = json.toRequestBody(mediaType)
2728
return Request.Builder()
28-
.url("https://api.openai.com/$endpoint")
29+
.url("$baseUrl/$endpoint")
2930
.addHeader("Content-Type", "application/json")
3031
.addHeader("Authorization", "Bearer $apiKey")
3132
.apply { if (organization != null) addHeader("OpenAI-Organization", organization) }
@@ -43,6 +44,7 @@ open class OpenAIImpl @ApiStatus.Internal constructor(
4344
val jsonReader = httpResponse.body?.byteStream()?.bufferedReader()
4445
?: throw IOException("Response body is null")
4546
val responseStr = jsonReader.readText()
47+
OpenAI.logger.debug(responseStr)
4648
return objectMapper.readValue(responseStr, responseType)
4749
}
4850

@@ -72,6 +74,8 @@ open class OpenAIImpl @ApiStatus.Internal constructor(
7274
var line: String?
7375
do {
7476
line = reader.readLine()
77+
OpenAI.logger.debug(line)
78+
7579
if (line == "data: [DONE]") {
7680
reader.close()
7781
return null
@@ -86,6 +90,7 @@ open class OpenAIImpl @ApiStatus.Internal constructor(
8690

8791
override fun next(): T {
8892
val line = nextLine ?: throw NoSuchElementException("No more lines")
93+
8994
currentResponse = if (currentResponse == null) {
9095
objectMapper.readValue(line, responseType)
9196
} else {
Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
package com.cjcrafter.openai.chat
2+
3+
import com.cjcrafter.openai.OpenAI
4+
import com.cjcrafter.openai.chat.ChatMessage.Companion.toSystemMessage
5+
import org.intellij.lang.annotations.Language
6+
import org.junit.jupiter.api.Assertions.*
7+
import org.junit.jupiter.params.ParameterizedTest
8+
import org.junit.jupiter.params.provider.Arguments
9+
import org.junit.jupiter.params.provider.MethodSource
10+
import java.util.stream.Stream
11+
12+
class ChatRequestTest {
13+
14+
@ParameterizedTest
15+
@MethodSource("provide_serialize")
16+
fun `test deserialize to json`(obj: Any, json: String) {
17+
val objectMapper = OpenAI.createObjectMapper()
18+
val expected = objectMapper.readTree(json)
19+
val actual = objectMapper.readTree(objectMapper.writeValueAsString(obj))
20+
assertEquals(expected, actual)
21+
}
22+
23+
@ParameterizedTest
24+
@MethodSource("provide_serialize")
25+
fun `test serialize from json`(expected: Any, json: String) {
26+
val objectMapper = OpenAI.createObjectMapper()
27+
val actual = objectMapper.readValue(json, expected::class.java)
28+
assertEquals(expected, actual)
29+
}
30+
31+
companion object {
32+
@JvmStatic
33+
fun provide_serialize(): Stream<Arguments> {
34+
return buildList<Arguments> {
35+
36+
@Language("JSON")
37+
var json = """
38+
{
39+
"messages": [
40+
{
41+
"role": "system",
42+
"content": "Be as helpful as possible"
43+
}
44+
],
45+
"model": "gpt-3.5-turbo"
46+
}
47+
""".trimIndent()
48+
add(Arguments.of(
49+
ChatRequest.builder()
50+
.model("gpt-3.5-turbo")
51+
.messages(mutableListOf("Be as helpful as possible".toSystemMessage()))
52+
.build(),
53+
json
54+
))
55+
56+
json = """
57+
{
58+
"messages": [
59+
{
60+
"role": "system",
61+
"content": "Be as helpful as possible"
62+
},
63+
{
64+
"role": "user",
65+
"content": "What is 2 + 2?"
66+
}
67+
],
68+
"model": "gpt-3.5-turbo",
69+
"tools": [
70+
{
71+
"type": "function",
72+
"function": {
73+
"name": "solve_math_problem",
74+
"parameters": {
75+
"type": "object",
76+
"properties": {
77+
"equation": {
78+
"type": "string",
79+
"description": "The math problem for you to solve"
80+
}
81+
},
82+
"required": [
83+
"equation"
84+
]
85+
},
86+
"description": "Returns the result of a math problem as a double"
87+
}
88+
}
89+
]
90+
}
91+
""".trimIndent()
92+
add(Arguments.of(
93+
chatRequest {
94+
model("gpt-3.5-turbo")
95+
messages(mutableListOf(
96+
ChatMessage(ChatUser.SYSTEM, "Be as helpful as possible"),
97+
ChatMessage(ChatUser.USER, "What is 2 + 2?")
98+
))
99+
function {
100+
name("solve_math_problem")
101+
description("Returns the result of a math problem as a double")
102+
addStringParameter("equation", "The math problem for you to solve", true)
103+
}
104+
},
105+
json
106+
))
107+
108+
}.stream()
109+
}
110+
}
111+
}

0 commit comments

Comments
 (0)