diff --git a/go.mod b/go.mod index 0d94208e6..6edefa29b 100644 --- a/go.mod +++ b/go.mod @@ -64,7 +64,7 @@ require ( github.com/aws/aws-sdk-go-v2/service/sts v1.33.19 // indirect github.com/bahlo/generic-list-go v0.2.0 // indirect github.com/buger/jsonparser v1.1.1 // indirect - github.com/ebitengine/purego v0.8.2 // indirect + github.com/ebitengine/purego v0.8.4 // indirect github.com/felixge/httpsnoop v1.0.4 // indirect github.com/go-logr/logr v1.4.2 // indirect github.com/go-logr/stdr v1.2.2 // indirect diff --git a/go.sum b/go.sum index 62020ac92..8bcdefcb6 100644 --- a/go.sum +++ b/go.sum @@ -60,8 +60,8 @@ github.com/cpuguy83/go-md2man/v2 v2.0.6/go.mod h1:oOW0eioCTA6cOiMLiUPZOpcVxMig6N github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/ebitengine/purego v0.8.2 h1:jPPGWs2sZ1UgOSgD2bClL0MJIqu58nOmIcBuXr62z1I= -github.com/ebitengine/purego v0.8.2/go.mod h1:iIjxzd6CiRiOG0UyXP+V1+jWqUXVjPKLAI0mRfJZTmQ= +github.com/ebitengine/purego v0.8.4 h1:CF7LEKg5FFOsASUj0+QwaXf8Ht6TlFxg09+S9wz0omw= +github.com/ebitengine/purego v0.8.4/go.mod h1:iIjxzd6CiRiOG0UyXP+V1+jWqUXVjPKLAI0mRfJZTmQ= github.com/emirpasic/gods v1.18.1 h1:FXtiHYKDGKCW2KzwZKx0iC0PQmdlorYgdFG9jPXJ1Bc= github.com/emirpasic/gods v1.18.1/go.mod h1:8tpGGwCnJ5H4r6BWwaV6OrWmMoPhUl5jm/FMNAnJvWQ= github.com/felixge/httpsnoop v1.0.4 h1:NFTV2Zj1bL4mc9sqWACXbQFVBBg2W3GPvqp8/ESS2Wg= diff --git a/pkg/waveai/openaibackend.go b/pkg/waveai/openaibackend.go index 36135884a..ca22fe768 100644 --- a/pkg/waveai/openaibackend.go +++ b/pkg/waveai/openaibackend.go @@ -55,6 +55,10 @@ func setApiType(opts *wshrpc.WaveAIOptsType, clientConfig *openaiapi.ClientConfi func convertPrompt(prompt []wshrpc.WaveAIPromptMessageType) []openaiapi.ChatCompletionMessage { var rtn []openaiapi.ChatCompletionMessage for _, p := range prompt { + // Filter out "error" role messages - they are not valid OpenAI roles + if p.Role == "error" { + continue + } msg := openaiapi.ChatCompletionMessage{Role: p.Role, Content: p.Content, Name: p.Name} rtn = append(rtn, msg) } @@ -106,8 +110,11 @@ func (OpenAIBackend) StreamCompletion(ctx context.Context, request wshrpc.WaveAI Messages: convertPrompt(request.Prompt), } - // Handle o1 models differently - use non-streaming API - if strings.HasPrefix(request.Opts.Model, "o1-") { + // Handle o1 and newer models (gpt-4.1+, o4+, o3+) - use non-streaming API with max_completion_tokens + if strings.HasPrefix(request.Opts.Model, "o1-") || + strings.HasPrefix(request.Opts.Model, "gpt-4.1") || + strings.HasPrefix(request.Opts.Model, "o4-") || + strings.HasPrefix(request.Opts.Model, "o3-") { req.MaxCompletionTokens = request.Opts.MaxTokens req.Stream = false