Skip to content

Commit 16e8cff

Browse files
committed
Add file part support to chat message structure
Introduces ChatMessagePartFile struct and ChatMessagePartTypeFile constant to support file attachments in chat messages. Updates ChatMessagePart to include file parts and adds comprehensive tests for serialization, deserialization, and constant definitions.
1 parent c4273cb commit 16e8cff

File tree

2 files changed

+120
-2
lines changed

2 files changed

+120
-2
lines changed

chat.go

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,17 +81,26 @@ type ChatMessageImageURL struct {
8181
Detail ImageURLDetail `json:"detail,omitempty"`
8282
}
8383

84+
// ChatMessagePartFile is a placeholder for file parts in chat messages.
85+
type ChatMessagePartFile struct {
86+
FileID string `json:"file_id,omitempty"`
87+
FileName string `json:"filename,omitempty"`
88+
FileData string `json:"file_data,omitempty"` // Base64 encoded file data
89+
}
90+
8491
type ChatMessagePartType string
8592

8693
const (
8794
ChatMessagePartTypeText ChatMessagePartType = "text"
8895
ChatMessagePartTypeImageURL ChatMessagePartType = "image_url"
96+
ChatMessagePartTypeFile ChatMessagePartType = "file"
8997
)
9098

9199
type ChatMessagePart struct {
92100
Type ChatMessagePartType `json:"type,omitempty"`
93101
Text string `json:"text,omitempty"`
94102
ImageURL *ChatMessageImageURL `json:"image_url,omitempty"`
103+
File *ChatMessagePartFile `json:"file,omitempty"`
95104
}
96105

97106
type ChatCompletionMessage struct {

chat_test.go

Lines changed: 111 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -677,6 +677,14 @@ func TestMultipartChatCompletions(t *testing.T) {
677677
Detail: openai.ImageURLDetailLow,
678678
},
679679
},
680+
{
681+
Type: openai.ChatMessagePartTypeFile,
682+
File: &openai.ChatMessagePartFile{
683+
FileID: "file-123",
684+
FileName: "test.txt",
685+
FileData: "dGVzdCBmaWxlIGNvbnRlbnQ=", // base64 encoded "test file content"
686+
},
687+
},
680688
},
681689
},
682690
},
@@ -687,7 +695,8 @@ func TestMultipartChatCompletions(t *testing.T) {
687695
func TestMultipartChatMessageSerialization(t *testing.T) {
688696
jsonText := `[{"role":"system","content":"system-message"},` +
689697
`{"role":"user","content":[{"type":"text","text":"nice-text"},` +
690-
`{"type":"image_url","image_url":{"url":"URL","detail":"high"}}]}]`
698+
`{"type":"image_url","image_url":{"url":"URL","detail":"high"}},` +
699+
`{"type":"file","file":{"file_id":"file-123","filename":"test.txt","file_data":"dGVzdA=="}}]}]`
691700

692701
var msgs []openai.ChatCompletionMessage
693702
err := json.Unmarshal([]byte(jsonText), &msgs)
@@ -700,7 +709,7 @@ func TestMultipartChatMessageSerialization(t *testing.T) {
700709
if msgs[0].Role != "system" || msgs[0].Content != "system-message" || msgs[0].MultiContent != nil {
701710
t.Errorf("invalid user message: %v", msgs[0])
702711
}
703-
if msgs[1].Role != "user" || msgs[1].Content != "" || len(msgs[1].MultiContent) != 2 {
712+
if msgs[1].Role != "user" || msgs[1].Content != "" || len(msgs[1].MultiContent) != 3 {
704713
t.Errorf("invalid user message")
705714
}
706715
parts := msgs[1].MultiContent
@@ -710,6 +719,9 @@ func TestMultipartChatMessageSerialization(t *testing.T) {
710719
if parts[1].Type != "image_url" || parts[1].ImageURL.URL != "URL" || parts[1].ImageURL.Detail != "high" {
711720
t.Errorf("invalid image_url part")
712721
}
722+
if parts[2].Type != "file" || parts[2].File.FileID != "file-123" || parts[2].File.FileName != "test.txt" || parts[2].File.FileData != "dGVzdA==" {
723+
t.Errorf("invalid file part: %v", parts[2])
724+
}
713725

714726
s, err := json.Marshal(msgs)
715727
if err != nil {
@@ -756,6 +768,103 @@ func TestMultipartChatMessageSerialization(t *testing.T) {
756768
}
757769
}
758770

771+
func TestChatMessagePartFile(t *testing.T) {
772+
// Test file part with FileID
773+
filePart := openai.ChatMessagePart{
774+
Type: openai.ChatMessagePartTypeFile,
775+
File: &openai.ChatMessagePartFile{
776+
FileID: "file-abc123",
777+
},
778+
}
779+
780+
// Test serialization
781+
data, err := json.Marshal(filePart)
782+
if err != nil {
783+
t.Fatalf("Expected no error: %s", err)
784+
}
785+
786+
expected := `{"type":"file","file":{"file_id":"file-abc123"}}`
787+
result := strings.ReplaceAll(string(data), " ", "")
788+
if result != expected {
789+
t.Errorf("Expected %s, got %s", expected, result)
790+
}
791+
792+
// Test deserialization
793+
var parsedPart openai.ChatMessagePart
794+
err = json.Unmarshal(data, &parsedPart)
795+
if err != nil {
796+
t.Fatalf("Expected no error: %s", err)
797+
}
798+
799+
if parsedPart.Type != openai.ChatMessagePartTypeFile {
800+
t.Errorf("Expected type %s, got %s", openai.ChatMessagePartTypeFile, parsedPart.Type)
801+
}
802+
if parsedPart.File == nil {
803+
t.Fatal("Expected File to be non-nil")
804+
}
805+
if parsedPart.File.FileID != "file-abc123" {
806+
t.Errorf("Expected FileID %s, got %s", "file-abc123", parsedPart.File.FileID)
807+
}
808+
809+
// Test file part with all fields
810+
filePartComplete := openai.ChatMessagePart{
811+
Type: openai.ChatMessagePartTypeFile,
812+
File: &openai.ChatMessagePartFile{
813+
FileID: "file-xyz789",
814+
FileName: "document.pdf",
815+
FileData: "JVBERi0xLjQK", // base64 for "%PDF-1.4\n"
816+
},
817+
}
818+
819+
data, err = json.Marshal(filePartComplete)
820+
if err != nil {
821+
t.Fatalf("Expected no error: %s", err)
822+
}
823+
824+
expected = `{"type":"file","file":{"file_id":"file-xyz789","filename":"document.pdf","file_data":"JVBERi0xLjQK"}}`
825+
result = strings.ReplaceAll(string(data), " ", "")
826+
if result != expected {
827+
t.Errorf("Expected %s, got %s", expected, result)
828+
}
829+
830+
// Test deserialization of complete file part
831+
var parsedCompleteFile openai.ChatMessagePart
832+
err = json.Unmarshal(data, &parsedCompleteFile)
833+
if err != nil {
834+
t.Fatalf("Expected no error: %s", err)
835+
}
836+
837+
if parsedCompleteFile.File.FileID != "file-xyz789" {
838+
t.Errorf("Expected FileID %s, got %s", "file-xyz789", parsedCompleteFile.File.FileID)
839+
}
840+
if parsedCompleteFile.File.FileName != "document.pdf" {
841+
t.Errorf("Expected FileName %s, got %s", "document.pdf", parsedCompleteFile.File.FileName)
842+
}
843+
if parsedCompleteFile.File.FileData != "JVBERi0xLjQK" {
844+
t.Errorf("Expected FileData %s, got %s", "JVBERi0xLjQK", parsedCompleteFile.File.FileData)
845+
}
846+
}
847+
848+
func TestChatMessagePartTypeConstants(t *testing.T) {
849+
// Test that the new file constant is properly defined
850+
if openai.ChatMessagePartTypeFile != "file" {
851+
t.Errorf("Expected ChatMessagePartTypeFile to be 'file', got %s", openai.ChatMessagePartTypeFile)
852+
}
853+
854+
// Test all part type constants
855+
expectedTypes := map[openai.ChatMessagePartType]string{
856+
openai.ChatMessagePartTypeText: "text",
857+
openai.ChatMessagePartTypeImageURL: "image_url",
858+
openai.ChatMessagePartTypeFile: "file",
859+
}
860+
861+
for constant, expected := range expectedTypes {
862+
if string(constant) != expected {
863+
t.Errorf("Expected %s to be %s, got %s", constant, expected, string(constant))
864+
}
865+
}
866+
}
867+
759868
// handleChatCompletionEndpoint Handles the ChatGPT completion endpoint by the test server.
760869
func handleChatCompletionEndpoint(w http.ResponseWriter, r *http.Request) {
761870
var err error

0 commit comments

Comments
 (0)