Skip to content

Commit 33605ac

Browse files
authored
Fix/enhance Go proxy buffering (#141)
* docs: add Chinese README * fix: enhance JSON buffering * fix: code cleanup
1 parent 09e9720 commit 33605ac

File tree

1 file changed

+132
-14
lines changed
  • src/emd/cfn/shared/openai_router

1 file changed

+132
-14
lines changed

src/emd/cfn/shared/openai_router/main.go

Lines changed: 132 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -284,6 +284,8 @@ func httpProxyHandler(c *gin.Context, endpointURL string, inputBytes []byte) {
284284
}
285285
_ = json.Unmarshal(inputBytes, &streamRequest) // Best effort check
286286

287+
// log.Printf("[DEBUG] ECS proxy handler - URL: %s, Streaming: %v", endpointURL, streamRequest.Stream)
288+
287289
client := &http.Client{
288290
Timeout: 15 * time.Minute,
289291
Transport: &http.Transport{
@@ -294,6 +296,8 @@ func httpProxyHandler(c *gin.Context, endpointURL string, inputBytes []byte) {
294296
// Add buffer sizes to handle large responses
295297
ReadBufferSize: 32 * 1024, // 32KB read buffer
296298
WriteBufferSize: 32 * 1024, // 32KB write buffer
299+
// Add response header timeout for better reliability
300+
ResponseHeaderTimeout: 30 * time.Second,
297301
},
298302
}
299303

@@ -334,10 +338,94 @@ func httpProxyHandler(c *gin.Context, endpointURL string, inputBytes []byte) {
334338
return
335339
}
336340

337-
// Stream the response
341+
// Create channel for streaming responses (same pattern as SageMaker)
342+
stream := make(chan []byte)
343+
closeOnce := sync.Once{}
344+
345+
// Start streaming in a goroutine (same pattern as SageMaker)
346+
go func() {
347+
defer closeOnce.Do(func() { close(stream) })
348+
349+
// Buffer for accumulating partial chunks (same as SageMaker)
350+
var buffer strings.Builder
351+
readBuffer := make([]byte, 4096)
352+
353+
for {
354+
n, err := resp.Body.Read(readBuffer)
355+
if n > 0 {
356+
chunk := string(readBuffer[:n])
357+
// log.Printf("[DEBUG] ECS received chunk: %s", chunk)
358+
359+
// Add chunk to buffer
360+
buffer.WriteString(chunk)
361+
bufferContent := buffer.String()
362+
363+
// Process complete lines from buffer (same logic as SageMaker)
364+
for strings.Contains(bufferContent, "\n") {
365+
lines := strings.SplitN(bufferContent, "\n", 2)
366+
if len(lines) < 2 {
367+
break
368+
}
369+
370+
line := strings.TrimSpace(lines[0])
371+
if line != "" {
372+
// Check if it's SSE data line with JSON
373+
if strings.HasPrefix(line, "data: ") {
374+
jsonPart := strings.TrimPrefix(line, "data: ")
375+
if jsonPart != "[DONE]" && jsonPart != "" {
376+
// Validate JSON content
377+
if !json.Valid([]byte(jsonPart)) {
378+
log.Printf("[WARNING] Invalid JSON in ECS SSE: %s", jsonPart)
379+
// Skip invalid JSON to prevent client parsing errors
380+
bufferContent = lines[1]
381+
buffer.Reset()
382+
buffer.WriteString(bufferContent)
383+
continue
384+
}
385+
}
386+
}
387+
388+
// Forward the complete line (whitespace-trimmed above) followed by a single newline
389+
stream <- []byte(line + "\n")
390+
} else {
391+
// Forward empty lines (important for SSE format)
392+
stream <- []byte("\n")
393+
}
394+
395+
// Update buffer with remaining content
396+
bufferContent = lines[1]
397+
buffer.Reset()
398+
buffer.WriteString(bufferContent)
399+
}
400+
}
401+
402+
if err != nil {
403+
if err == io.EOF {
404+
// Process any remaining data in buffer
405+
if buffer.Len() > 0 {
406+
remaining := strings.TrimSpace(buffer.String())
407+
if remaining != "" {
408+
stream <- []byte(remaining + "\n")
409+
}
410+
}
411+
return // End of stream
412+
}
413+
log.Printf("[ERROR] Error reading from ECS stream: %v", err)
414+
return
415+
}
416+
}
417+
}()
418+
419+
// Stream responses to client (same pattern as SageMaker)
338420
c.Stream(func(w io.Writer) bool {
339-
_, err := io.Copy(w, resp.Body)
340-
return err == nil
421+
if msg, ok := <-stream; ok {
422+
_, err := w.Write(msg)
423+
if err != nil {
424+
return false
425+
}
426+
return true
427+
}
428+
return false
341429
})
342430
} else {
343431
// Non-streaming request
@@ -361,6 +449,13 @@ func httpProxyHandler(c *gin.Context, endpointURL string, inputBytes []byte) {
361449
return
362450
}
363451

452+
// Check for empty response
453+
if len(body) == 0 {
454+
log.Printf("[ERROR] Empty response from ECS endpoint")
455+
c.JSON(500, gin.H{"error": "Empty response from ECS endpoint"})
456+
return
457+
}
458+
364459
// Validate JSON response if content type is application/json
365460
contentType := resp.Header.Get("Content-Type")
366461
if strings.Contains(contentType, "application/json") {
@@ -375,6 +470,7 @@ func httpProxyHandler(c *gin.Context, endpointURL string, inputBytes []byte) {
375470
}
376471
return
377472
}
473+
// log.Printf("[DEBUG] ECS response length: %d bytes", len(body))
378474
}
379475

380476
// Copy response headers
@@ -634,8 +730,10 @@ func requestHandler(c *gin.Context) {
634730
eventStream := resp.GetStream()
635731
defer eventStream.Close()
636732

637-
// Buffer for accumulating partial chunks
733+
// Enhanced buffer for accumulating partial chunks with better handling
638734
var buffer strings.Builder
735+
var lastValidJSON string
736+
chunkCount := 0
639737

640738
for event := range eventStream.Events() {
641739
switch e := event.(type) {
@@ -646,13 +744,14 @@ func requestHandler(c *gin.Context) {
646744
}
647745

648746
chunk := string(e.Bytes)
649-
// log.Printf("[DEBUG] Received raw chunk: %s", chunk)
747+
chunkCount++
748+
// log.Printf("[DEBUG] Received chunk #%d: %s", chunkCount, chunk)
650749

651750
// Add chunk to buffer
652751
buffer.WriteString(chunk)
653752
bufferContent := buffer.String()
654753

655-
// Process complete lines from buffer (SSE data should be line-based)
754+
// Process complete lines from buffer with enhanced validation
656755
for strings.Contains(bufferContent, "\n") {
657756
lines := strings.SplitN(bufferContent, "\n", 2)
658757
if len(lines) < 2 {
@@ -661,23 +760,33 @@ func requestHandler(c *gin.Context) {
661760

662761
line := strings.TrimSpace(lines[0])
663762
if line != "" {
664-
// Validate JSON
763+
// Enhanced JSON validation with recovery mechanisms
665764
if json.Valid([]byte(line)) {
765+
lastValidJSON = line
666766
// Format as proper SSE event and send
667767
formattedChunk := "data: " + line
668768

669-
// Check for finish_reason=stop to end stream
769+
// Check for finish_reason to end stream (including length limit)
670770
if strings.Contains(line, `"finish_reason":"stop"`) ||
671-
strings.Contains(line, `"finish_reason": "stop"`) {
672-
// log.Printf("[DEBUG] Detected finish_reason=stop, ending stream")
771+
strings.Contains(line, `"finish_reason": "stop"`) ||
772+
strings.Contains(line, `"finish_reason":"length"`) ||
773+
strings.Contains(line, `"finish_reason": "length"`) {
774+
// log.Printf("[DEBUG] Detected finish_reason, ending stream")
673775
stream <- []byte(formattedChunk + "\n\n")
674776
return // Exit the goroutine completely
675777
}
676778

677779
// Forward as properly formatted SSE event
678780
stream <- []byte(formattedChunk + "\n\n")
679781
} else {
680-
log.Printf("[WARNING] Invalid JSON line: %s", line)
782+
log.Printf("[WARNING] Invalid JSON line (chunk #%d): %s", chunkCount, line)
783+
// Try to recover by checking if it's a partial JSON that can be completed
784+
if isPartialJSON(line) {
785+
log.Printf("[WARNING] Detected partial JSON, attempting recovery")
786+
// Don't forward partial JSON, wait for more data
787+
} else {
788+
log.Printf("[WARNING] Completely invalid JSON, skipping")
789+
}
681790
}
682791
}
683792

@@ -688,16 +797,25 @@ func requestHandler(c *gin.Context) {
688797
}
689798

690799
case *sagemakerruntime.InternalStreamFailure:
800+
log.Printf("[ERROR] SageMaker stream failure: %v", e.Error())
691801
stream <- []byte(`data: {"error": "` + e.Error() + `"}` + "\n\n")
692802
return
693803
}
694804
}
695805

696-
// Process any remaining data in buffer
806+
// Process any remaining data in buffer with enhanced validation
697807
if buffer.Len() > 0 {
698808
remaining := strings.TrimSpace(buffer.String())
699-
if remaining != "" && json.Valid([]byte(remaining)) {
700-
stream <- []byte("data: " + remaining + "\n\n")
809+
if remaining != "" {
810+
if json.Valid([]byte(remaining)) {
811+
stream <- []byte("data: " + remaining + "\n\n")
812+
} else {
813+
log.Printf("[WARNING] Discarding invalid JSON remainder (total chunks: %d): %s", chunkCount, remaining)
814+
// No recovery is attempted here: if an earlier line was valid JSON and this remainder looks partial, only log that the stream may have been truncated
815+
if isPartialJSON(remaining) && lastValidJSON != "" {
816+
log.Printf("[WARNING] Final chunk appears to be partial JSON, stream may have been truncated")
817+
}
818+
}
701819
}
702820
}
703821

0 commit comments

Comments
 (0)