@@ -59,6 +59,38 @@ func TestCompletions(t *testing.T) {
5959 checks .NoError (t , err , "CreateCompletion error" )
6060}
6161
62+ // TestMultiplePromptsCompletionsWrong Tests the completions endpoint of the API using the mocked server
63+ // where the completions requests has a list of prompts with wrong type.
64+ func TestMultiplePromptsCompletionsWrong (t * testing.T ) {
65+ client , server , teardown := setupOpenAITestServer ()
66+ defer teardown ()
67+ server .RegisterHandler ("/v1/completions" , handleCompletionEndpoint )
68+ req := openai.CompletionRequest {
69+ MaxTokens : 5 ,
70+ Model : "ada" ,
71+ Prompt : []interface {}{"Lorem ipsum" , 9 },
72+ }
73+ _ , err := client .CreateCompletion (context .Background (), req )
74+ if ! errors .Is (err , openai .ErrCompletionRequestPromptTypeNotSupported ) {
75+ t .Fatalf ("CreateCompletion should return ErrCompletionRequestPromptTypeNotSupported, but returned: %v" , err )
76+ }
77+ }
78+
79+ // TestMultiplePromptsCompletions Tests the completions endpoint of the API using the mocked server
80+ // where the completions requests has a list of prompts.
81+ func TestMultiplePromptsCompletions (t * testing.T ) {
82+ client , server , teardown := setupOpenAITestServer ()
83+ defer teardown ()
84+ server .RegisterHandler ("/v1/completions" , handleCompletionEndpoint )
85+ req := openai.CompletionRequest {
86+ MaxTokens : 5 ,
87+ Model : "ada" ,
88+ Prompt : []interface {}{"Lorem ipsum" , "Lorem ipsum" },
89+ }
90+ _ , err := client .CreateCompletion (context .Background (), req )
91+ checks .NoError (t , err , "CreateCompletion error" )
92+ }
93+
6294// handleCompletionEndpoint Handles the completion endpoint by the test server.
// NOTE(review): this span is a scraped diff; the hunk below omits the function's
// middle (request decoding and the initialization of completionReq, n, and res),
// so only the visible fragments are annotated here.
6395func handleCompletionEndpoint (w http.ResponseWriter , r * http.Request ) {
6496 var err error
@@ -87,24 +119,50 @@ func handleCompletionEndpoint(w http.ResponseWriter, r *http.Request) {
// n presumably comes from completionReq.N (set in the hidden portion) — confirm;
// a zero value is normalized to one completion per prompt.
87119 if n == 0 {
88120 n = 1
89121 }
122+ // Handle different types of prompts: single string or list of strings
123+ prompts := []string {}
124+ switch v := completionReq .Prompt .(type ) {
125+ case string :
126+ prompts = append (prompts , v )
127+ case []interface {}:
// Non-string elements in the list are silently skipped here; the client is
// expected to reject such requests before they reach this mock handler.
128+ for _ , item := range v {
129+ if str , ok := item .(string ); ok {
130+ prompts = append (prompts , str )
131+ }
132+ }
133+ default :
134+ http .Error (w , "Invalid prompt type" , http .StatusBadRequest )
135+ return
136+ }
137+
// Emit n choices per prompt; Index is the running position in res.Choices so
// indices stay unique across the prompt × n grid.
90138 for i := 0 ; i < n ; i ++ {
91- // generate a random string of length completionReq.Length
92- completionStr := strings .Repeat ("a" , completionReq .MaxTokens )
93- if completionReq .Echo {
94- completionStr = completionReq .Prompt .(string ) + completionStr
139+ for _ , prompt := range prompts {
140+ // Generate a random string of length completionReq.MaxTokens
141+ completionStr := strings .Repeat ("a" , completionReq .MaxTokens )
142+ if completionReq .Echo {
143+ completionStr = prompt + completionStr
144+ }
145+
146+ res .Choices = append (res .Choices , openai.CompletionChoice {
147+ Text : completionStr ,
148+ Index : len (res .Choices ),
149+ })
150+ }
95150 }
96- res .Choices = append (res .Choices , openai.CompletionChoice {
97- Text : completionStr ,
98- Index : i ,
99- })
100151 }
101- inputTokens := numTokens (completionReq .Prompt .(string )) * n
102- completionTokens := completionReq .MaxTokens * n
152+
// Usage mirrors the generation above: prompt tokens summed over all prompts and
// scaled by n; completion tokens are MaxTokens per generated choice.
153+ inputTokens := 0
154+ for _ , prompt := range prompts {
155+ inputTokens += numTokens (prompt )
156+ }
157+ inputTokens *= n
158+ completionTokens := completionReq .MaxTokens * len (prompts ) * n
103159 res .Usage = openai.Usage {
104160 PromptTokens : inputTokens ,
105161 CompletionTokens : completionTokens ,
106162 TotalTokens : inputTokens + completionTokens ,
107163 }
164+
165+ // Serialize the response and send it back
108166 resBytes , _ = json .Marshal (res )
109167 fmt .Fprintln (w , string (resBytes ))
110168}
0 commit comments