36
36
IdleConnTimeout : 90 * time .Second ,
37
37
DisableCompression : false ,
38
38
MaxIdleConnsPerHost : 20 ,
39
+ // Add buffer sizes to handle large responses
40
+ ReadBufferSize : 32 * 1024 , // 32KB read buffer
41
+ WriteBufferSize : 32 * 1024 , // 32KB write buffer
39
42
}
40
43
httpClient = & http.Client {
41
44
Transport : httpTransport ,
@@ -80,12 +83,24 @@ func main() {
80
83
81
84
// Create router with ALB support
82
85
router := gin .Default ()
86
+
87
+ // Add request logging middleware
83
88
router .Use (func (c * gin.Context ) {
89
+ start := time .Now ()
90
+
84
91
// Trust X-Forwarded-For header
85
92
if forwardedFor := c .GetHeader ("X-Forwarded-For" ); forwardedFor != "" {
86
93
c .Request .RemoteAddr = strings .Split (forwardedFor , "," )[0 ]
87
94
}
95
+
88
96
c .Next ()
97
+
98
+ // Log request details after processing
99
+ duration := time .Since (start )
100
+ if logLevel == "DEBUG" {
101
+ log .Printf ("[DEBUG] %s %s - Status: %d, Duration: %v, Size: %d bytes" ,
102
+ c .Request .Method , c .Request .URL .Path , c .Writer .Status (), duration , c .Writer .Size ())
103
+ }
89
104
})
90
105
91
106
// Health checks (ALB requires / and /health)
@@ -276,6 +291,9 @@ func httpProxyHandler(c *gin.Context, endpointURL string, inputBytes []byte) {
276
291
IdleConnTimeout : 90 * time .Second ,
277
292
DisableCompression : false ,
278
293
MaxIdleConnsPerHost : 20 ,
294
+ // Add buffer sizes to handle large responses
295
+ ReadBufferSize : 32 * 1024 , // 32KB read buffer
296
+ WriteBufferSize : 32 * 1024 , // 32KB write buffer
279
297
},
280
298
}
281
299
@@ -336,14 +354,37 @@ func httpProxyHandler(c *gin.Context, endpointURL string, inputBytes []byte) {
336
354
return
337
355
}
338
356
339
- // Forward the response
357
+ // Forward the response with better error handling
340
358
body , err := io .ReadAll (resp .Body )
341
359
if err != nil {
342
360
c .JSON (500 , gin.H {"error" : fmt .Sprintf ("Failed to read response body: %v" , err )})
343
361
return
344
362
}
345
363
346
- c .Data (resp .StatusCode , resp .Header .Get ("Content-Type" ), body )
364
+ // Validate JSON response if content type is application/json
365
+ contentType := resp .Header .Get ("Content-Type" )
366
+ if strings .Contains (contentType , "application/json" ) {
367
+ if ! json .Valid (body ) {
368
+ bodyStr := string (body )
369
+ log .Printf ("[ERROR] Invalid JSON response from ECS endpoint: %s" , bodyStr )
370
+ if isPartialJSON (bodyStr ) {
371
+ log .Printf ("[ERROR] Detected partial JSON response from ECS, likely truncated" )
372
+ c .JSON (500 , gin.H {"error" : "Response was truncated, please try again" })
373
+ } else {
374
+ c .JSON (500 , gin.H {"error" : "Invalid JSON response from backend" })
375
+ }
376
+ return
377
+ }
378
+ }
379
+
380
+ // Copy response headers
381
+ for k , v := range resp .Header {
382
+ if k != "Content-Length" { // Let Gin handle Content-Length
383
+ c .Header (k , strings .Join (v , "," ))
384
+ }
385
+ }
386
+
387
+ c .Data (resp .StatusCode , contentType , body )
347
388
}
348
389
}
349
390
@@ -363,6 +404,52 @@ func isPartialJSON(s string) bool {
363
404
return openBraces > closeBraces
364
405
}
365
406
407
// findCompleteJSON scans s for the first balanced JSON container (an object
// or an array) and returns the byte offset just past its closing delimiter,
// so s[:n] ends with that container. Text that precedes the container is
// skipped, not validated.
//
// It returns -1 when s is empty, when every container is still open at the
// end of s (for example a truncated response), or when a closing delimiter
// does not match the innermost open one (for example "{]").
//
// Only structure is checked: delimiters inside string literals are ignored
// and backslash escapes inside strings are honored, but the content between
// delimiters is not validated as JSON.
func findCompleteJSON(s string) int {
	var (
		closers  []rune // closing delimiter expected for each open container, innermost last
		inString bool
		escaped  bool
	)

	for i, r := range s {
		// The character following a backslash inside a string is literal.
		if escaped {
			escaped = false
			continue
		}

		if inString {
			switch r {
			case '\\':
				escaped = true
			case '"':
				inString = false
			}
			continue
		}

		switch r {
		case '"':
			inString = true
		case '{':
			closers = append(closers, '}')
		case '[':
			closers = append(closers, ']')
		case '}', ']':
			if len(closers) == 0 {
				// Stray closer before any container was opened: treat it as
				// leading noise so it cannot unbalance the containers that follow.
				continue
			}
			if closers[len(closers)-1] != r {
				// Mismatched pair: the input is malformed, not complete.
				return -1
			}
			closers = closers[:len(closers)-1]
			if len(closers) == 0 {
				// All delimiters are single-byte ASCII, so i+1 is the byte
				// offset immediately after the closing delimiter.
				return i + 1
			}
		}
	}

	// Ran out of input with a container still open, or found none at all.
	return -1
}
452
+
366
453
func getEndpointForModel (modelKey string ) (string , error ) {
367
454
// Only accept model_id/model_tag format
368
455
if ! strings .Contains (modelKey , "/" ) {
@@ -547,6 +634,9 @@ func requestHandler(c *gin.Context) {
547
634
eventStream := resp .GetStream ()
548
635
defer eventStream .Close ()
549
636
637
+ // Buffer for accumulating partial chunks
638
+ var buffer strings.Builder
639
+
550
640
for event := range eventStream .Events () {
551
641
switch e := event .(type ) {
552
642
case * sagemakerruntime.PayloadPart :
@@ -555,27 +645,62 @@ func requestHandler(c *gin.Context) {
555
645
continue
556
646
}
557
647
558
- chunk := string (e .Bytes )
559
- // log.Printf("[DEBUG] Received chunk: %s", chunk)
560
-
561
- // Format as proper SSE event
562
- formattedChunk := "data: " + chunk
563
-
564
- // Check for finish_reason=stop to end stream
565
- if strings .Contains (chunk , `"finish_reason":"stop"` ) ||
566
- strings .Contains (chunk , `"finish_reason": "stop"` ) {
567
- // log.Printf("[DEBUG] Detected finish_reason=stop, ending stream")
568
- stream <- []byte (formattedChunk + "\n \n " )
569
- return // Exit the goroutine completely
570
- }
648
+ chunk := string (e .Bytes )
649
+ // log.Printf("[DEBUG] Received raw chunk: %s", chunk)
650
+
651
+ // Add chunk to buffer
652
+ buffer .WriteString (chunk )
653
+ bufferContent := buffer .String ()
654
+
655
+ // Process complete lines from buffer (SSE data should be line-based)
656
+ for strings .Contains (bufferContent , "\n " ) {
657
+ lines := strings .SplitN (bufferContent , "\n " , 2 )
658
+ if len (lines ) < 2 {
659
+ break
660
+ }
661
+
662
+ line := strings .TrimSpace (lines [0 ])
663
+ if line != "" {
664
+ // Validate JSON
665
+ if json .Valid ([]byte (line )) {
666
+ // Format as proper SSE event and send
667
+ formattedChunk := "data: " + line
668
+
669
+ // Check for finish_reason=stop to end stream
670
+ if strings .Contains (line , `"finish_reason":"stop"` ) ||
671
+ strings .Contains (line , `"finish_reason": "stop"` ) {
672
+ // log.Printf("[DEBUG] Detected finish_reason=stop, ending stream")
673
+ stream <- []byte (formattedChunk + "\n \n " )
674
+ return // Exit the goroutine completely
675
+ }
676
+
677
+ // Forward as properly formatted SSE event
678
+ stream <- []byte (formattedChunk + "\n \n " )
679
+ } else {
680
+ log .Printf ("[WARNING] Invalid JSON line: %s" , line )
681
+ }
682
+ }
683
+
684
+ // Update buffer with remaining content
685
+ bufferContent = lines [1 ]
686
+ buffer .Reset ()
687
+ buffer .WriteString (bufferContent )
688
+ }
571
689
572
- // Forward as properly formatted SSE event
573
- stream <- []byte (formattedChunk + "\n \n " )
574
690
case * sagemakerruntime.InternalStreamFailure :
575
691
stream <- []byte (`data: {"error": "` + e .Error () + `"}` + "\n \n " )
576
692
return
577
693
}
578
694
}
695
+
696
+ // Process any remaining data in buffer
697
+ if buffer .Len () > 0 {
698
+ remaining := strings .TrimSpace (buffer .String ())
699
+ if remaining != "" && json .Valid ([]byte (remaining )) {
700
+ stream <- []byte ("data: " + remaining + "\n \n " )
701
+ }
702
+ }
703
+
579
704
// Send final done message
580
705
stream <- []byte ("data: [DONE]\n \n " )
581
706
}()
@@ -592,8 +717,11 @@ func requestHandler(c *gin.Context) {
592
717
})
593
718
594
719
} else {
595
- // Non-streaming request
596
- output , err := sagemakerClient .InvokeEndpoint (& sagemakerruntime.InvokeEndpointInput {
720
+ // Non-streaming request with timeout context
721
+ ctx , cancel := context .WithTimeout (c .Request .Context (), 10 * time .Minute )
722
+ defer cancel ()
723
+
724
+ output , err := sagemakerClient .InvokeEndpointWithContext (ctx , & sagemakerruntime.InvokeEndpointInput {
597
725
EndpointName : aws .String (endpointAddress ),
598
726
ContentType : aws .String ("application/json" ),
599
727
Body : modifiedBytes ,
@@ -611,8 +739,24 @@ func requestHandler(c *gin.Context) {
611
739
return
612
740
}
613
741
614
- // log.Printf("[DEBUG] SageMaker response: %s", string(output.Body))
615
- // Forward raw SageMaker response
742
+ // Validate JSON before forwarding to catch truncation issues
743
+ responseStr := string (output .Body )
744
+ if ! json .Valid (output .Body ) {
745
+ log .Printf ("[ERROR] Invalid JSON response from SageMaker: %s" , responseStr )
746
+ // Try to detect if it's a partial JSON
747
+ if isPartialJSON (responseStr ) {
748
+ log .Printf ("[ERROR] Detected partial JSON response, likely truncated" )
749
+ c .JSON (500 , gin.H {"error" : "Response was truncated, please try again" })
750
+ } else {
751
+ c .JSON (500 , gin.H {"error" : "Invalid JSON response from backend" })
752
+ }
753
+ return
754
+ }
755
+
756
+ // log.Printf("[DEBUG] SageMaker response length: %d bytes", len(output.Body))
757
+ // Set proper headers and forward response
758
+ c .Header ("Content-Type" , "application/json" )
759
+ c .Header ("Content-Length" , fmt .Sprintf ("%d" , len (output .Body )))
616
760
c .Data (200 , "application/json" , output .Body )
617
761
}
618
762
} else if endpointType == "ecs" {
0 commit comments