From 1421a0833e2fa706f9eaf7a887fe15bb09aea7c8 Mon Sep 17 00:00:00 2001 From: Megumin Date: Wed, 2 Apr 2025 21:13:38 +0800 Subject: [PATCH 1/6] feat: temperature pointer --- api_integration_test.go | 6 +++--- assistant_test.go | 4 ++-- audio.go | 2 +- audio_api_test.go | 6 +++--- batch_test.go | 4 ++-- chat.go | 2 +- chat_stream_test.go | 4 ++-- chat_test.go | 22 +++++++++++++--------- client.go | 2 +- client_test.go | 4 ++-- completion_test.go | 4 ++-- config_test.go | 2 +- edits_test.go | 4 ++-- embeddings_test.go | 4 ++-- engines_test.go | 4 ++-- error_test.go | 2 +- example_test.go | 2 +- examples/chatbot/main.go | 2 +- examples/completion-with-tool/main.go | 4 ++-- examples/completion/main.go | 2 +- examples/images/main.go | 2 +- examples/voice-to-text/main.go | 2 +- files_api_test.go | 4 ++-- files_test.go | 4 ++-- fine_tunes_test.go | 4 ++-- fine_tuning_job_test.go | 4 ++-- go.mod | 2 +- image_api_test.go | 4 ++-- image_test.go | 4 ++-- internal/error_accumulator_test.go | 4 ++-- internal/form_builder_test.go | 2 +- internal/test/helpers.go | 2 +- jsonschema/json_test.go | 2 +- jsonschema/validate_test.go | 2 +- messages_test.go | 6 +++--- models_test.go | 4 ++-- moderation_test.go | 4 ++-- openai_test.go | 4 ++-- reasoning_validator.go | 2 +- run_test.go | 4 ++-- speech_test.go | 6 +++--- stream_reader.go | 2 +- stream_reader_test.go | 6 +++--- stream_test.go | 4 ++-- thread_test.go | 4 ++-- vector_store_test.go | 4 ++-- 46 files changed, 91 insertions(+), 87 deletions(-) diff --git a/api_integration_test.go b/api_integration_test.go index 7828d9451..9f55c56e5 100644 --- a/api_integration_test.go +++ b/api_integration_test.go @@ -10,9 +10,9 @@ import ( "os" "testing" - "github.com/sashabaranov/go-openai" - "github.com/sashabaranov/go-openai/internal/test/checks" - "github.com/sashabaranov/go-openai/jsonschema" + "github.com/meguminnnnnnnnn/go-openai" + "github.com/meguminnnnnnnnn/go-openai/internal/test/checks" + "github.com/meguminnnnnnnnn/go-openai/jsonschema" ) func TestAPI(t *testing.T) { diff --git a/assistant_test.go b/assistant_test.go index 40de0e50f..7ae0b5a2e 100644 --- a/assistant_test.go +++ b/assistant_test.go @@ -3,8 +3,8 @@ package openai_test import ( "context" - openai "github.com/sashabaranov/go-openai" - "github.com/sashabaranov/go-openai/internal/test/checks" + openai "github.com/meguminnnnnnnnn/go-openai" + "github.com/meguminnnnnnnnn/go-openai/internal/test/checks" "encoding/json" "fmt" diff --git a/audio.go b/audio.go index f321f93d6..636b897eb 100644 --- a/audio.go +++ b/audio.go @@ -8,7 +8,7 @@ import ( "net/http" "os" - utils "github.com/sashabaranov/go-openai/internal" + utils "github.com/meguminnnnnnnnn/go-openai/internal" ) // Whisper Defines the models provided by OpenAI to use when processing audio with OpenAI. diff --git a/audio_api_test.go b/audio_api_test.go index 6c6a35643..af3e12493 100644 --- a/audio_api_test.go +++ b/audio_api_test.go @@ -12,9 +12,9 @@ import ( "strings" "testing" - "github.com/sashabaranov/go-openai" - "github.com/sashabaranov/go-openai/internal/test" - "github.com/sashabaranov/go-openai/internal/test/checks" + "github.com/meguminnnnnnnnn/go-openai" + "github.com/meguminnnnnnnnn/go-openai/internal/test" + "github.com/meguminnnnnnnnn/go-openai/internal/test/checks" ) // TestAudio Tests the transcription and translation endpoints of the API using the mocked server. diff --git a/batch_test.go b/batch_test.go index f4714f4eb..9504944b4 100644 --- a/batch_test.go +++ b/batch_test.go @@ -7,8 +7,8 @@ import ( "reflect" "testing" - "github.com/sashabaranov/go-openai" - "github.com/sashabaranov/go-openai/internal/test/checks" + "github.com/meguminnnnnnnnn/go-openai" + "github.com/meguminnnnnnnnn/go-openai/internal/test/checks" ) func TestUploadBatchFile(t *testing.T) { diff --git a/chat.go b/chat.go index c8a3e81b3..3621a48c8 100644 --- a/chat.go +++ b/chat.go @@ -233,7 +233,7 @@ type ChatCompletionRequest struct { // MaxCompletionTokens An upper bound for the number of tokens that can be generated for a completion, // including visible output tokens and reasoning tokens https://platform.openai.com/docs/guides/reasoning MaxCompletionTokens int `json:"max_completion_tokens,omitempty"` - Temperature float32 `json:"temperature,omitempty"` + Temperature *float32 `json:"temperature,omitempty"` TopP float32 `json:"top_p,omitempty"` N int `json:"n,omitempty"` Stream bool `json:"stream,omitempty"` diff --git a/chat_stream_test.go b/chat_stream_test.go index eabb0f3a2..0e19b44d7 100644 --- a/chat_stream_test.go +++ b/chat_stream_test.go @@ -10,8 +10,8 @@ import ( "strconv" "testing" - "github.com/sashabaranov/go-openai" - "github.com/sashabaranov/go-openai/internal/test/checks" + "github.com/meguminnnnnnnnn/go-openai" + "github.com/meguminnnnnnnnn/go-openai/internal/test/checks" ) func TestChatCompletionsStreamWrongModel(t *testing.T) { diff --git a/chat_test.go b/chat_test.go index 514706c96..673390087 100644 --- a/chat_test.go +++ b/chat_test.go @@ -12,9 +12,9 @@ import ( "testing" "time" - "github.com/sashabaranov/go-openai" - "github.com/sashabaranov/go-openai/internal/test/checks" - "github.com/sashabaranov/go-openai/jsonschema" + "github.com/meguminnnnnnnnn/go-openai" + "github.com/meguminnnnnnnnn/go-openai/internal/test/checks" + "github.com/meguminnnnnnnnn/go-openai/jsonschema" ) const ( @@ -91,6 +91,10 @@ func TestO1ModelsChatCompletionsDeprecatedFields(t *testing.T) { } } +func ptrOf[T any](v T) *T { + return &v +} + func TestO1ModelsChatCompletionsBetaLimitations(t *testing.T) { tests := []struct { name string @@ -119,7 +123,7 @@ func TestO1ModelsChatCompletionsBetaLimitations(t *testing.T) { Role: openai.ChatMessageRoleAssistant, }, }, - Temperature: float32(2), + Temperature: ptrOf(float32(2)), }, expectedError: openai.ErrReasoningModelLimitationsOther, }, @@ -136,7 +140,7 @@ func TestO1ModelsChatCompletionsBetaLimitations(t *testing.T) { Role: openai.ChatMessageRoleAssistant, }, }, - Temperature: float32(1), + Temperature: ptrOf(float32(1)), TopP: float32(0.1), }, expectedError: openai.ErrReasoningModelLimitationsOther, @@ -154,7 +158,7 @@ func TestO1ModelsChatCompletionsBetaLimitations(t *testing.T) { Role: openai.ChatMessageRoleAssistant, }, }, - Temperature: float32(1), + Temperature: ptrOf(float32(1)), TopP: float32(1), N: 2, }, @@ -239,7 +243,7 @@ func TestO3ModelsChatCompletionsBetaLimitations(t *testing.T) { Role: openai.ChatMessageRoleAssistant, }, }, - Temperature: float32(2), + Temperature: ptrOf(float32(2)), }, expectedError: openai.ErrReasoningModelLimitationsOther, }, @@ -256,7 +260,7 @@ func TestO3ModelsChatCompletionsBetaLimitations(t *testing.T) { Role: openai.ChatMessageRoleAssistant, }, }, - Temperature: float32(1), + Temperature: ptrOf(float32(1)), TopP: float32(0.1), }, expectedError: openai.ErrReasoningModelLimitationsOther, @@ -274,7 +278,7 @@ func TestO3ModelsChatCompletionsBetaLimitations(t *testing.T) { Role: openai.ChatMessageRoleAssistant, }, }, - Temperature: float32(1), + Temperature: ptrOf(float32(1)), TopP: float32(1), N: 2, }, diff --git a/client.go b/client.go index cef375348..373d53f21 100644 --- a/client.go +++ b/client.go @@ -10,7 +10,7 @@ import ( "net/url" "strings" - utils "github.com/sashabaranov/go-openai/internal" + utils "github.com/meguminnnnnnnnn/go-openai/internal" ) // Client is OpenAI GPT-3 API client. diff --git a/client_test.go b/client_test.go index 321971445..e333759df 100644 --- a/client_test.go +++ b/client_test.go @@ -10,8 +10,8 @@ import ( "reflect" "testing" - "github.com/sashabaranov/go-openai/internal/test" - "github.com/sashabaranov/go-openai/internal/test/checks" + "github.com/meguminnnnnnnnn/go-openai/internal/test" + "github.com/meguminnnnnnnnn/go-openai/internal/test/checks" ) var errTestRequestBuilderFailed = errors.New("test request builder failed") diff --git a/completion_test.go b/completion_test.go index 27e2d150e..63c83dce3 100644 --- a/completion_test.go +++ b/completion_test.go @@ -12,8 +12,8 @@ import ( "testing" "time" - "github.com/sashabaranov/go-openai" - "github.com/sashabaranov/go-openai/internal/test/checks" + "github.com/meguminnnnnnnnn/go-openai" + "github.com/meguminnnnnnnnn/go-openai/internal/test/checks" ) func TestCompletionsWrongModel(t *testing.T) { diff --git a/config_test.go b/config_test.go index 960230804..a86e2f232 100644 --- a/config_test.go +++ b/config_test.go @@ -3,7 +3,7 @@ package openai_test import ( "testing" - "github.com/sashabaranov/go-openai" + "github.com/meguminnnnnnnnn/go-openai" ) func TestGetAzureDeploymentByModel(t *testing.T) { diff --git a/edits_test.go b/edits_test.go index d2a6db40d..1898d77ce 100644 --- a/edits_test.go +++ b/edits_test.go @@ -9,8 +9,8 @@ import ( "testing" "time" - "github.com/sashabaranov/go-openai" - "github.com/sashabaranov/go-openai/internal/test/checks" + "github.com/meguminnnnnnnnn/go-openai" + "github.com/meguminnnnnnnnn/go-openai/internal/test/checks" ) // TestEdits Tests the edits endpoint of the API using the mocked server. diff --git a/embeddings_test.go b/embeddings_test.go index 438978169..192e3ddcb 100644 --- a/embeddings_test.go +++ b/embeddings_test.go @@ -11,8 +11,8 @@ import ( "reflect" "testing" - "github.com/sashabaranov/go-openai" - "github.com/sashabaranov/go-openai/internal/test/checks" + "github.com/meguminnnnnnnnn/go-openai" + "github.com/meguminnnnnnnnn/go-openai/internal/test/checks" ) func TestEmbedding(t *testing.T) { diff --git a/engines_test.go b/engines_test.go index d26aa5541..90b7973be 100644 --- a/engines_test.go +++ b/engines_test.go @@ -7,8 +7,8 @@ import ( "net/http" "testing" - "github.com/sashabaranov/go-openai" - "github.com/sashabaranov/go-openai/internal/test/checks" + "github.com/meguminnnnnnnnn/go-openai" + "github.com/meguminnnnnnnnn/go-openai/internal/test/checks" ) // TestGetEngine Tests the retrieve engine endpoint of the API using the mocked server. diff --git a/error_test.go b/error_test.go index 48cbe4f29..1d8fe5e2d 100644 --- a/error_test.go +++ b/error_test.go @@ -6,7 +6,7 @@ import ( "reflect" "testing" - "github.com/sashabaranov/go-openai" + "github.com/meguminnnnnnnnn/go-openai" ) func TestAPIErrorUnmarshalJSON(t *testing.T) { diff --git a/example_test.go b/example_test.go index 5910ffb84..1a55952b7 100644 --- a/example_test.go +++ b/example_test.go @@ -11,7 +11,7 @@ import ( "net/url" "os" - "github.com/sashabaranov/go-openai" + "github.com/meguminnnnnnnnn/go-openai" ) func Example() { diff --git a/examples/chatbot/main.go b/examples/chatbot/main.go index ad41e957d..e4895dac4 100644 --- a/examples/chatbot/main.go +++ b/examples/chatbot/main.go @@ -6,7 +6,7 @@ import ( "fmt" "os" - "github.com/sashabaranov/go-openai" + "github.com/meguminnnnnnnnn/go-openai" ) func main() { diff --git a/examples/completion-with-tool/main.go b/examples/completion-with-tool/main.go index 26126e41b..181066dba 100644 --- a/examples/completion-with-tool/main.go +++ b/examples/completion-with-tool/main.go @@ -5,8 +5,8 @@ import ( "fmt" "os" - "github.com/sashabaranov/go-openai" - "github.com/sashabaranov/go-openai/jsonschema" + "github.com/meguminnnnnnnnn/go-openai" + "github.com/meguminnnnnnnnn/go-openai/jsonschema" ) func main() { diff --git a/examples/completion/main.go b/examples/completion/main.go index 8c5cbd5ca..b1b980f78 100644 --- a/examples/completion/main.go +++ b/examples/completion/main.go @@ -5,7 +5,7 @@ import ( "fmt" "os" - "github.com/sashabaranov/go-openai" + "github.com/meguminnnnnnnnn/go-openai" ) func main() { diff --git a/examples/images/main.go b/examples/images/main.go index 5ee649d22..eca84afd9 100644 --- a/examples/images/main.go +++ b/examples/images/main.go @@ -5,7 +5,7 @@ import ( "fmt" "os" - "github.com/sashabaranov/go-openai" + "github.com/meguminnnnnnnnn/go-openai" ) func main() { diff --git a/examples/voice-to-text/main.go b/examples/voice-to-text/main.go index 713e748e1..d1ddc4fd1 100644 --- a/examples/voice-to-text/main.go +++ b/examples/voice-to-text/main.go @@ -6,7 +6,7 @@ import ( "fmt" "os" - "github.com/sashabaranov/go-openai" + "github.com/meguminnnnnnnnn/go-openai" ) func main() { diff --git a/files_api_test.go b/files_api_test.go index aa4fda458..22245f0b4 100644 --- a/files_api_test.go +++ b/files_api_test.go @@ -12,8 +12,8 @@ import ( "testing" "time" - "github.com/sashabaranov/go-openai" - "github.com/sashabaranov/go-openai/internal/test/checks" + "github.com/meguminnnnnnnnn/go-openai" + "github.com/meguminnnnnnnnn/go-openai/internal/test/checks" ) func TestFileBytesUpload(t *testing.T) { diff --git a/files_test.go b/files_test.go index 3c1b99fb4..1960e2394 100644 --- a/files_test.go +++ b/files_test.go @@ -7,8 +7,8 @@ import ( "os" "testing" - utils "github.com/sashabaranov/go-openai/internal" - "github.com/sashabaranov/go-openai/internal/test/checks" + utils "github.com/meguminnnnnnnnn/go-openai/internal" + "github.com/meguminnnnnnnnn/go-openai/internal/test/checks" ) func TestFileBytesUploadWithFailingFormBuilder(t *testing.T) { diff --git a/fine_tunes_test.go b/fine_tunes_test.go index 2ab6817f7..39bd8eea9 100644 --- a/fine_tunes_test.go +++ b/fine_tunes_test.go @@ -7,8 +7,8 @@ import ( "net/http" "testing" - "github.com/sashabaranov/go-openai" - "github.com/sashabaranov/go-openai/internal/test/checks" + "github.com/meguminnnnnnnnn/go-openai" + "github.com/meguminnnnnnnnn/go-openai/internal/test/checks" ) const testFineTuneID = "fine-tune-id" diff --git a/fine_tuning_job_test.go b/fine_tuning_job_test.go index 5f63ef24c..892dff7c9 100644 --- a/fine_tuning_job_test.go +++ b/fine_tuning_job_test.go @@ -7,8 +7,8 @@ import ( "net/http" "testing" - "github.com/sashabaranov/go-openai" - "github.com/sashabaranov/go-openai/internal/test/checks" + "github.com/meguminnnnnnnnn/go-openai" + "github.com/meguminnnnnnnnn/go-openai/internal/test/checks" ) const testFineTuninigJobID = "fine-tuning-job-id" diff --git a/go.mod b/go.mod index 42cc7b391..3b781ed20 100644 --- a/go.mod +++ b/go.mod @@ -1,3 +1,3 @@ -module github.com/sashabaranov/go-openai +module github.com/meguminnnnnnnnn/go-openai go 1.18 diff --git a/image_api_test.go b/image_api_test.go index f6057b77d..7c35b857a 100644 --- a/image_api_test.go +++ b/image_api_test.go @@ -11,8 +11,8 @@ import ( "testing" "time" - "github.com/sashabaranov/go-openai" - "github.com/sashabaranov/go-openai/internal/test/checks" + "github.com/meguminnnnnnnnn/go-openai" + "github.com/meguminnnnnnnnn/go-openai/internal/test/checks" ) func TestImages(t *testing.T) { diff --git a/image_test.go b/image_test.go index 644005515..f14121695 100644 --- a/image_test.go +++ b/image_test.go @@ -1,8 +1,8 @@ package openai //nolint:testpackage // testing private field import ( - utils "github.com/sashabaranov/go-openai/internal" - "github.com/sashabaranov/go-openai/internal/test/checks" + utils "github.com/meguminnnnnnnnn/go-openai/internal" + "github.com/meguminnnnnnnnn/go-openai/internal/test/checks" "context" "fmt" diff --git a/internal/error_accumulator_test.go b/internal/error_accumulator_test.go index d48f28177..f76ade0b9 100644 --- a/internal/error_accumulator_test.go +++ b/internal/error_accumulator_test.go @@ -5,8 +5,8 @@ import ( "errors" "testing" - utils "github.com/sashabaranov/go-openai/internal" - "github.com/sashabaranov/go-openai/internal/test" + utils "github.com/meguminnnnnnnnn/go-openai/internal" + "github.com/meguminnnnnnnnn/go-openai/internal/test" ) func TestErrorAccumulatorBytes(t *testing.T) { diff --git a/internal/form_builder_test.go b/internal/form_builder_test.go index 76922c1ba..51d4c90ce 100644 --- a/internal/form_builder_test.go +++ b/internal/form_builder_test.go @@ -1,7 +1,7 @@ package openai //nolint:testpackage // testing private field import ( - "github.com/sashabaranov/go-openai/internal/test/checks" + "github.com/meguminnnnnnnnn/go-openai/internal/test/checks" "bytes" "errors" diff --git a/internal/test/helpers.go b/internal/test/helpers.go index dc5fa6646..5c638ef01 100644 --- a/internal/test/helpers.go +++ b/internal/test/helpers.go @@ -1,7 +1,7 @@ package test import ( - "github.com/sashabaranov/go-openai/internal/test/checks" + "github.com/meguminnnnnnnnn/go-openai/internal/test/checks" "net/http" "os" diff --git a/jsonschema/json_test.go b/jsonschema/json_test.go index 84f25fa85..af97dcb46 100644 --- a/jsonschema/json_test.go +++ b/jsonschema/json_test.go @@ -5,7 +5,7 @@ import ( "reflect" "testing" - "github.com/sashabaranov/go-openai/jsonschema" + "github.com/meguminnnnnnnnn/go-openai/jsonschema" ) func TestDefinition_MarshalJSON(t *testing.T) { diff --git a/jsonschema/validate_test.go b/jsonschema/validate_test.go index 6fa30ab0c..026c5e21f 100644 --- a/jsonschema/validate_test.go +++ b/jsonschema/validate_test.go @@ -3,7 +3,7 @@ package jsonschema_test import ( "testing" - "github.com/sashabaranov/go-openai/jsonschema" + "github.com/meguminnnnnnnnn/go-openai/jsonschema" ) func Test_Validate(t *testing.T) { diff --git a/messages_test.go b/messages_test.go index b25755f98..a726adf04 100644 --- a/messages_test.go +++ b/messages_test.go @@ -7,9 +7,9 @@ import ( "net/http" "testing" - "github.com/sashabaranov/go-openai" - "github.com/sashabaranov/go-openai/internal/test" - "github.com/sashabaranov/go-openai/internal/test/checks" + "github.com/meguminnnnnnnnn/go-openai" + "github.com/meguminnnnnnnnn/go-openai/internal/test" + "github.com/meguminnnnnnnnn/go-openai/internal/test/checks" ) var emptyStr = "" diff --git a/models_test.go b/models_test.go index 7fd010c34..ab70a6857 100644 --- a/models_test.go +++ b/models_test.go @@ -9,8 +9,8 @@ import ( "testing" "time" - "github.com/sashabaranov/go-openai" - "github.com/sashabaranov/go-openai/internal/test/checks" + "github.com/meguminnnnnnnnn/go-openai" + "github.com/meguminnnnnnnnn/go-openai/internal/test/checks" ) const testFineTuneModelID = "fine-tune-model-id" diff --git a/moderation_test.go b/moderation_test.go index a97f25bc6..95cb879b7 100644 --- a/moderation_test.go +++ b/moderation_test.go @@ -11,8 +11,8 @@ import ( "testing" "time" - "github.com/sashabaranov/go-openai" - "github.com/sashabaranov/go-openai/internal/test/checks" + "github.com/meguminnnnnnnnn/go-openai" + "github.com/meguminnnnnnnnn/go-openai/internal/test/checks" ) // TestModeration Tests the moderations endpoint of the API using the mocked server. diff --git a/openai_test.go b/openai_test.go index a55f3a858..cabaf10a4 100644 --- a/openai_test.go +++ b/openai_test.go @@ -1,8 +1,8 @@ package openai_test import ( - "github.com/sashabaranov/go-openai" - "github.com/sashabaranov/go-openai/internal/test" + "github.com/meguminnnnnnnnn/go-openai" + "github.com/meguminnnnnnnnn/go-openai/internal/test" ) func setupOpenAITestServer() (client *openai.Client, server *test.ServerTest, teardown func()) { diff --git a/reasoning_validator.go b/reasoning_validator.go index 2910b1395..b8bc51d2c 100644 --- a/reasoning_validator.go +++ b/reasoning_validator.go @@ -61,7 +61,7 @@ func (v *ReasoningValidator) validateReasoningModelParams(request ChatCompletion if request.LogProbs { return ErrReasoningModelLimitationsLogprobs } - if request.Temperature > 0 && request.Temperature != 1 { + if request.Temperature != nil { return ErrReasoningModelLimitationsOther } if request.TopP > 0 && request.TopP != 1 { diff --git a/run_test.go b/run_test.go index cdf99db05..02505e981 100644 --- a/run_test.go +++ b/run_test.go @@ -3,8 +3,8 @@ package openai_test import ( "context" - openai "github.com/sashabaranov/go-openai" - "github.com/sashabaranov/go-openai/internal/test/checks" + openai "github.com/meguminnnnnnnnn/go-openai" + "github.com/meguminnnnnnnnn/go-openai/internal/test/checks" "encoding/json" "fmt" diff --git a/speech_test.go b/speech_test.go index 67a3feabc..3f1cedf47 100644 --- a/speech_test.go +++ b/speech_test.go @@ -11,9 +11,9 @@ import ( "path/filepath" "testing" - "github.com/sashabaranov/go-openai" - "github.com/sashabaranov/go-openai/internal/test" - "github.com/sashabaranov/go-openai/internal/test/checks" + "github.com/meguminnnnnnnnn/go-openai" + "github.com/meguminnnnnnnnn/go-openai/internal/test" + "github.com/meguminnnnnnnnn/go-openai/internal/test/checks" ) func TestSpeechIntegration(t *testing.T) { diff --git a/stream_reader.go b/stream_reader.go index 6faefe0a7..17cf31866 100644 --- a/stream_reader.go +++ b/stream_reader.go @@ -8,7 +8,7 @@ import ( "net/http" "regexp" - utils "github.com/sashabaranov/go-openai/internal" + utils "github.com/meguminnnnnnnnn/go-openai/internal" ) var ( diff --git a/stream_reader_test.go b/stream_reader_test.go index 449a14b43..4098fba08 100644 --- a/stream_reader_test.go +++ b/stream_reader_test.go @@ -6,9 +6,9 @@ import ( "errors" "testing" - utils "github.com/sashabaranov/go-openai/internal" - "github.com/sashabaranov/go-openai/internal/test" - "github.com/sashabaranov/go-openai/internal/test/checks" + utils "github.com/meguminnnnnnnnn/go-openai/internal" + "github.com/meguminnnnnnnnn/go-openai/internal/test" + "github.com/meguminnnnnnnnn/go-openai/internal/test/checks" ) var errTestUnmarshalerFailed = errors.New("test unmarshaler failed") diff --git a/stream_test.go b/stream_test.go index 9dd95bb5f..3156360a0 100644 --- a/stream_test.go +++ b/stream_test.go @@ -10,8 +10,8 @@ import ( "testing" "time" - "github.com/sashabaranov/go-openai" - "github.com/sashabaranov/go-openai/internal/test/checks" + "github.com/meguminnnnnnnnn/go-openai" + "github.com/meguminnnnnnnnn/go-openai/internal/test/checks" ) func TestCompletionsStreamWrongModel(t *testing.T) { diff --git a/thread_test.go b/thread_test.go index 1ac0f3c0e..c8fbe98ce 100644 --- a/thread_test.go +++ b/thread_test.go @@ -7,8 +7,8 @@ import ( "net/http" "testing" - openai "github.com/sashabaranov/go-openai" - "github.com/sashabaranov/go-openai/internal/test/checks" + openai "github.com/meguminnnnnnnnn/go-openai" + "github.com/meguminnnnnnnnn/go-openai/internal/test/checks" ) // TestThread Tests the thread endpoint of the API using the mocked server. diff --git a/vector_store_test.go b/vector_store_test.go index 58b9a857e..2ddaef976 100644 --- a/vector_store_test.go +++ b/vector_store_test.go @@ -3,8 +3,8 @@ package openai_test import ( "context" - openai "github.com/sashabaranov/go-openai" - "github.com/sashabaranov/go-openai/internal/test/checks" + openai "github.com/meguminnnnnnnnn/go-openai" + "github.com/meguminnnnnnnnn/go-openai/internal/test/checks" "encoding/json" "fmt" From b107e16d61a0e6673ca1313098f6c2533ebfbaa8 Mon Sep 17 00:00:00 2001 From: Megumin Date: Tue, 8 Apr 2025 15:15:28 +0800 Subject: [PATCH 2/6] feat: add multi content type --- chat.go | 23 ++++++++++++++++++----- 1 file changed, 18 insertions(+), 5 deletions(-) diff --git a/chat.go b/chat.go index 3621a48c8..a89b23d9b 100644 --- a/chat.go +++ b/chat.go @@ -79,17 +79,30 @@ type ChatMessageImageURL struct { Detail ImageURLDetail `json:"detail,omitempty"` } +type ChatMessageInputAudio struct { + Data string `json:"data,omitempty"` + Format string `json:"format,omitempty"` +} + +type ChatMessageVideoURL struct { + URL string `json:"url,omitempty"` +} + type ChatMessagePartType string const ( - ChatMessagePartTypeText ChatMessagePartType = "text" - ChatMessagePartTypeImageURL ChatMessagePartType = "image_url" + ChatMessagePartTypeText ChatMessagePartType = "text" + ChatMessagePartTypeImageURL ChatMessagePartType = "image_url" + ChatMessagePartTypeInputAudio ChatMessagePartType = "input_audio" + ChatMessagePartTypeVideoURL ChatMessagePartType = "video_url" ) type ChatMessagePart struct { - Type ChatMessagePartType `json:"type,omitempty"` - Text string `json:"text,omitempty"` - ImageURL *ChatMessageImageURL `json:"image_url,omitempty"` + Type ChatMessagePartType `json:"type,omitempty"` + Text string `json:"text,omitempty"` + ImageURL *ChatMessageImageURL `json:"image_url,omitempty"` + InputAudio *ChatMessageInputAudio `json:"input_audio,omitempty"` + VideoURL *ChatMessageVideoURL `json:"video_url,omitempty"` } type ChatCompletionMessage struct { From 0ec1babd3dd578882bd1ea5c4d6db902e4e1107c Mon Sep 17 00:00:00 2001 From: Back Yu Date: Mon, 19 May 2025 15:41:23 +0800 Subject: [PATCH 3/6] =?UTF-8?q?chat:=20=E6=96=B0=E5=A2=9EExtraFields?= =?UTF-8?q?=E6=8E=A5=E5=8F=A3?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- README.md | 62 +++++++++++++++++++++++++++++--- chat.go | 24 +++++++++---- chat_test.go | 49 +++++++++++++++++++++++++ go.mod | 4 +++ go.sum | 5 +++ internal/marshaller.go | 30 +++++++++++++++- internal/request_builder.go | 1 + internal/request_builder_test.go | 49 +++++++++++++++++++++++++ 8 files changed, 211 insertions(+), 13 deletions(-) create mode 100644 go.sum diff --git a/README.md b/README.md index 77b85e519..a8eabd06b 100644 --- a/README.md +++ b/README.md @@ -3,7 +3,7 @@ [![Go Report Card](https://goreportcard.com/badge/github.com/sashabaranov/go-openai)](https://goreportcard.com/report/github.com/sashabaranov/go-openai) [![codecov](https://codecov.io/gh/sashabaranov/go-openai/branch/master/graph/badge.svg?token=bCbIfHLIsW)](https://codecov.io/gh/sashabaranov/go-openai) -This library provides unofficial Go clients for [OpenAI API](https://platform.openai.com/). We support: +This library provides unofficial Go clients for [OpenAI API](https://platform.openai.com/). We support: * ChatGPT 4o, o1 * GPT-3, GPT-4 @@ -720,7 +720,7 @@ if errors.As(err, &e) { case 401: // invalid auth or key (do not retry) case 429: - // rate limiting or engine overload (wait and retry) + // rate limiting or engine overload (wait and retry) case 500: // openai server error (retry) default: @@ -867,6 +867,58 @@ func main() { } ``` + +
+Using ExtraFields + +```go +package main + +import ( + "context" + "fmt" + openai "github.com/sashabaranov/go-openai" +) + +func main() { + client := openai.NewClient("your token") + ctx := context.Background() + + // Create chat request + req := openai.ChatCompletionRequest{ + Model: openai.GPT3Dot5Turbo, + Messages: []openai.ChatCompletionMessage{ + { + Role: openai.ChatMessageRoleUser, + Content: "Hello!", + }, + }, + } + + // Add custom fields + extraFields := map[string]any{ + "custom_field": "test_value", + "numeric_field": 42, + "bool_field": true, + } + req.SetExtraFields(extraFields) + + // Get custom fields + gotFields := req.GetExtraFields() + fmt.Printf("Extra fields: %v\n", gotFields) + + // Send request + resp, err := client.CreateChatCompletion(ctx, req) + if err != nil { + fmt.Printf("ChatCompletion error: %v\n", err) + return + } + + fmt.Println(resp.Choices[0].Message.Content) +} +``` +
+ See the `examples/` folder for more. ## Frequently Asked Questions @@ -887,18 +939,18 @@ Due to the factors mentioned above, different answers may be returned even for t By adopting these strategies, you can expect more consistent results. -**Related Issues:** +**Related Issues:** [omitempty option of request struct will generate incorrect request when parameter is 0.](https://github.com/sashabaranov/go-openai/issues/9) ### Does Go OpenAI provide a method to count tokens? No, Go OpenAI does not offer a feature to count tokens, and there are no plans to provide such a feature in the future. However, if there's a way to implement a token counting feature with zero dependencies, it might be possible to merge that feature into Go OpenAI. Otherwise, it would be more appropriate to implement it in a dedicated library or repository. -For counting tokens, you might find the following links helpful: +For counting tokens, you might find the following links helpful: - [Counting Tokens For Chat API Calls](https://github.com/pkoukk/tiktoken-go#counting-tokens-for-chat-api-calls) - [How to count tokens with tiktoken](https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb) -**Related Issues:** +**Related Issues:** [Is it possible to join the implementation of GPT3 Tokenizer](https://github.com/sashabaranov/go-openai/issues/62) ## Contributing diff --git a/chat.go b/chat.go index a89b23d9b..a5c66644e 100644 --- a/chat.go +++ b/chat.go @@ -286,13 +286,23 @@ type ChatCompletionRequest struct { ReasoningEffort string `json:"reasoning_effort,omitempty"` // Metadata to store with the completion. Metadata map[string]string `json:"metadata,omitempty"` - // Configuration for a predicted output. - Prediction *Prediction `json:"prediction,omitempty"` - // ChatTemplateKwargs provides a way to add non-standard parameters to the request body. - // Additional kwargs to pass to the template renderer. Will be accessible by the chat template. - // Such as think mode for qwen3. "chat_template_kwargs": {"enable_thinking": false} - // https://qwen.readthedocs.io/en/latest/deployment/vllm.html#thinking-non-thinking-modes - ChatTemplateKwargs map[string]any `json:"chat_template_kwargs,omitempty"` + + // Extra fields to be sent in the request. + // Useful for experimental features not yet officially supported. + extraFields map[string]any +} + +// SetExtraFields adds extra fields to the JSON object. +// +// SetExtraFields will override any existing fields with the same key. +// For security reasons, ensure this is only used with trusted input data. +func (r *ChatCompletionRequest) SetExtraFields(extraFields map[string]any) { + r.extraFields = extraFields +} + +// GetExtraFields returns the extra fields set in the request. +func (r *ChatCompletionRequest) GetExtraFields() map[string]any { + return r.extraFields } type StreamOptions struct { diff --git a/chat_test.go b/chat_test.go index 673390087..c587bf80d 100644 --- a/chat_test.go +++ b/chat_test.go @@ -950,3 +950,52 @@ func TestFinishReason(t *testing.T) { } } } + +func TestChatCompletionRequestExtraFields(t *testing.T) { + req := openai.ChatCompletionRequest{ + Model: "gpt-4", + } + + // 测试设置额外字段 + extraFields := map[string]any{ + "custom_field": "test_value", + "numeric_field": 42, + "bool_field": true, + } + req.SetExtraFields(extraFields) + + // 测试获取额外字段 + gotFields := req.GetExtraFields() + + // 验证字段数量 + if len(gotFields) != len(extraFields) { + t.Errorf("Expected %d extra fields, got %d", len(extraFields), len(gotFields)) + } + + // 验证字段值 + for key, expectedValue := range extraFields { + gotValue, exists := gotFields[key] + if !exists { + t.Errorf("Expected field %s not found", key) + continue + } + if gotValue != expectedValue { + t.Errorf("Field %s: expected %v, got %v", key, expectedValue, gotValue) + } + } + + // 测试覆盖已存在的字段 + newFields := map[string]any{ + "custom_field": "new_value", + } + req.SetExtraFields(newFields) + gotFields = req.GetExtraFields() + + if len(gotFields) != len(newFields) { + t.Errorf("Expected %d extra fields after override, got %d", len(newFields), len(gotFields)) + } + + if gotFields["custom_field"] != "new_value" { + t.Errorf("Expected overridden value 'new_value', got %v", gotFields["custom_field"]) + } +} diff --git a/go.mod b/go.mod index 3b781ed20..e7952240e 100644 --- a/go.mod +++ b/go.mod @@ -1,3 +1,7 @@ module github.com/meguminnnnnnnnn/go-openai go 1.18 + +require github.com/evanphx/json-patch v0.5.2 + +require github.com/pkg/errors v0.9.1 // indirect diff --git a/go.sum b/go.sum new file mode 100644 index 000000000..f41860f17 --- /dev/null +++ b/go.sum @@ -0,0 +1,5 @@ +github.com/evanphx/json-patch v0.5.2 h1:xVCHIVMUu1wtM/VkR9jVZ45N3FhZfYMMYGorLCR8P3k= +github.com/evanphx/json-patch v0.5.2/go.mod h1:ZWS5hhDbVDyob71nXKNL0+PWn6ToqBHMikGIFbs31qQ= +github.com/jessevdk/go-flags v1.4.0/go.mod h1:4FA24M0QyGHXBuZZK/XkWh8h0e1EYbRYJSGM75WSRxI= +github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= +github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= diff --git a/internal/marshaller.go b/internal/marshaller.go index 223a4dc1c..4dc7f8e88 100644 --- a/internal/marshaller.go +++ b/internal/marshaller.go @@ -2,6 +2,9 @@ package openai import ( "encoding/json" + "fmt" + + jsonpatch "github.com/evanphx/json-patch" ) type Marshaller interface { @@ -11,5 +14,30 @@ type Marshaller interface { type JSONMarshaller struct{} func (jm *JSONMarshaller) Marshal(value any) ([]byte, error) { - return json.Marshal(value) + originalBytes, err := json.Marshal(value) + if err != nil { + return nil, err + } + // Check if the value implements the GetExtraFields interface + getExtraFieldsBody, ok := value.(interface { + GetExtraFields() map[string]any + }) + if !ok { + // If not, return the original bytes + return originalBytes, nil + } + extraFields := getExtraFieldsBody.GetExtraFields() + if len(extraFields) == 0 { + // If there are no extra fields, return the original bytes + return originalBytes, nil + } + patchBytes, err := json.Marshal(extraFields) + if err != nil { + return nil, fmt.Errorf("Marshal extraFields(%+v) err: %w", extraFields, err) + } + finalBytes, err := jsonpatch.MergePatch(originalBytes, patchBytes) + if err != nil { + return nil, fmt.Errorf("MergePatch originalBytes(%s) patchBytes(%s) err: %w", originalBytes, patchBytes, err) + } + return finalBytes, nil } diff --git a/internal/request_builder.go b/internal/request_builder.go index 5699f6b18..de3a9814d 100644 --- a/internal/request_builder.go +++ b/internal/request_builder.go @@ -38,6 +38,7 @@ func (b *HTTPRequestBuilder) Build( if err != nil { return } + bodyReader = bytes.NewBuffer(reqBytes) } } diff --git a/internal/request_builder_test.go b/internal/request_builder_test.go index e26022a6b..b235cfb6d 100644 --- a/internal/request_builder_test.go +++ b/internal/request_builder_test.go @@ -3,6 +3,7 @@ package openai //nolint:testpackage // testing private field import ( "bytes" "context" + "encoding/json" "errors" "net/http" "reflect" @@ -59,3 +60,51 @@ func TestRequestBuilderReturnsRequestWhenRequestOfArgsIsNil(t *testing.T) { t.Errorf("Build() got = %v, want %v", got, want) } } + +type testExtraFieldsRequest struct { + Model string `json:"model"` + extraFields map[string]any +} + +func (r *testExtraFieldsRequest) GetExtraFields() map[string]any { + return r.extraFields +} + +func TestRequestBuilderReturnsRequestWhenRequestHasExtraFields(t *testing.T) { + b := NewRequestBuilder() + var ( + ctx = context.Background() + method = http.MethodPost + url = "/foo" + request = &testExtraFieldsRequest{ + Model: "test-model", + } + ) + request.extraFields = map[string]any{"extra_field": "extra_value"} + + reqBytes, err := b.marshaller.Marshal(request) + if err != nil { + t.Fatalf("Marshal failed: %v", err) + } + + // 验证序列化结果包含原始字段和额外字段 + var result map[string]interface{} + if err := json.Unmarshal(reqBytes, &result); err != nil { + t.Fatalf("Unmarshal failed: %v", err) + } + + if result["model"] != "test-model" { + t.Errorf("Expected model to be 'test-model', got %v", result["model"]) + } + if result["extra_field"] != "extra_value" { + t.Errorf("Expected extra_field to be 'extra_value', got %v", result["extra_field"]) + } + + want, _ := http.NewRequestWithContext(ctx, method, url, bytes.NewBuffer(reqBytes)) + got, _ := b.Build(ctx, method, url, request, nil) + if !reflect.DeepEqual(got.Body, want.Body) || + !reflect.DeepEqual(got.URL, want.URL) || + !reflect.DeepEqual(got.Method, want.Method) { + t.Errorf("Build() got = %v, want %v", got, want) + } +} From 65cf08bcac2d39c28bcb868fc185f2df92e6627b Mon Sep 17 00:00:00 2001 From: Back Yu Date: Mon, 19 May 2025 16:18:53 +0800 Subject: [PATCH 4/6] =?UTF-8?q?chat:=20=E4=BF=AE=E5=A4=8DGetExtraFields()?= =?UTF-8?q?=E5=BC=95=E7=94=A8?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- chat.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/chat.go b/chat.go index a5c66644e..c83369d77 100644 --- a/chat.go +++ b/chat.go @@ -301,7 +301,7 @@ func (r *ChatCompletionRequest) SetExtraFields(extraFields map[string]any) { } // GetExtraFields returns the extra fields set in the request. -func (r *ChatCompletionRequest) GetExtraFields() map[string]any { +func (r ChatCompletionRequest) GetExtraFields() map[string]any { return r.extraFields } From 3a9c552bd221c45d6a953b11f5756ad20a30d709 Mon Sep 17 00:00:00 2001 From: Megumin Date: Mon, 9 Jun 2025 17:53:47 +0800 Subject: [PATCH 5/6] fix: GetExtraFields & gomod --- audio_test.go | 6 +++--- chat.go | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/audio_test.go b/audio_test.go index 51b3f465d..ac2d65327 100644 --- a/audio_test.go +++ b/audio_test.go @@ -11,9 +11,9 @@ import ( "path/filepath" "testing" - utils "github.com/sashabaranov/go-openai/internal" - "github.com/sashabaranov/go-openai/internal/test" - "github.com/sashabaranov/go-openai/internal/test/checks" + utils "github.com/meguminnnnnnnnn/go-openai/internal" + "github.com/meguminnnnnnnnn/go-openai/internal/test" + "github.com/meguminnnnnnnnn/go-openai/internal/test/checks" ) func TestAudioWithFailingFormBuilder(t *testing.T) { diff --git a/chat.go b/chat.go index c83369d77..a5c66644e 100644 --- a/chat.go +++ b/chat.go @@ -301,7 +301,7 @@ func (r *ChatCompletionRequest) SetExtraFields(extraFields map[string]any) { } // GetExtraFields returns the extra fields set in the request. -func (r ChatCompletionRequest) GetExtraFields() map[string]any { +func (r *ChatCompletionRequest) GetExtraFields() map[string]any { return r.extraFields } From 0d508a1dcddeb9c2d908eb19af1aa9efa5ba2be1 Mon Sep 17 00:00:00 2001 From: Megumin Date: Fri, 20 Jun 2025 17:28:28 +0800 Subject: [PATCH 6/6] fix: pointer to extra fields --- chat.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/chat.go b/chat.go index a5c66644e..c83369d77 100644 --- a/chat.go +++ b/chat.go @@ -301,7 +301,7 @@ func (r *ChatCompletionRequest) SetExtraFields(extraFields map[string]any) { } // GetExtraFields returns the extra fields set in the request. -func (r *ChatCompletionRequest) GetExtraFields() map[string]any { +func (r ChatCompletionRequest) GetExtraFields() map[string]any { return r.extraFields }