From 9957d6969ddada8b943989adf8dd9422b69eb81a Mon Sep 17 00:00:00 2001 From: Shivam Kumar Singh Date: Fri, 21 Jul 2023 19:56:10 +0530 Subject: [PATCH] Fix lint failure for openai binding (#3000) Signed-off-by: Shivam Kumar Singh Signed-off-by: ItalyPaleAle <43508+ItalyPaleAle@users.noreply.github.com> Co-authored-by: ItalyPaleAle <43508+ItalyPaleAle@users.noreply.github.com> --- bindings/azure/openai/openai.go | 80 +++++++++++++++++---------------- go.mod | 2 +- go.sum | 4 +- 3 files changed, 44 insertions(+), 42 deletions(-) diff --git a/bindings/azure/openai/openai.go b/bindings/azure/openai/openai.go index 9df8009a33..81743b5c81 100644 --- a/bindings/azure/openai/openai.go +++ b/bindings/azure/openai/openai.go @@ -26,7 +26,6 @@ import ( "github.com/dapr/components-contrib/bindings" azauth "github.com/dapr/components-contrib/internal/authentication/azure" "github.com/dapr/components-contrib/metadata" - "github.com/dapr/kit/config" "github.com/dapr/kit/logger" ) @@ -51,8 +50,9 @@ const ( // AzOpenAI represents OpenAI output binding. type AzOpenAI struct { - logger logger.Logger - client *azopenai.Client + logger logger.Logger + client *azopenai.Client + deploymentID string } type openAIMetadata struct { @@ -64,15 +64,6 @@ type openAIMetadata struct { Endpoint string `mapstructure:"endpoint"` } -type ChatSettings struct { - Temperature float32 `mapstructure:"temperature"` - MaxTokens int32 `mapstructure:"maxTokens"` - TopP float32 `mapstructure:"topP"` - N int32 `mapstructure:"n"` - PresencePenalty float32 `mapstructure:"presencePenalty"` - FrequencyPenalty float32 `mapstructure:"frequencyPenalty"` -} - // ChatMessages type for chat completion API. type ChatMessages struct { Messages []Message `json:"messages"` @@ -82,6 +73,7 @@ type ChatMessages struct { N int32 `json:"n"` PresencePenalty float32 `json:"presencePenalty"` FrequencyPenalty float32 `json:"frequencyPenalty"` + Stop []string `json:"stop"` } // Message type stores the messages for bot conversation. @@ -92,13 +84,14 @@ type Message struct { // Prompt type for completion API. type Prompt struct { - Prompt string `json:"prompt"` - Temperature float32 `json:"temperature"` - MaxTokens int32 `json:"maxTokens"` - TopP float32 `json:"topP"` - N int32 `json:"n"` - PresencePenalty float32 `json:"presencePenalty"` - FrequencyPenalty float32 `json:"frequencyPenalty"` + Prompt string `json:"prompt"` + Temperature float32 `json:"temperature"` + MaxTokens int32 `json:"maxTokens"` + TopP float32 `json:"topP"` + N int32 `json:"n"` + PresencePenalty float32 `json:"presencePenalty"` + FrequencyPenalty float32 `json:"frequencyPenalty"` + Stop []string `json:"stop"` } // NewOpenAI returns a new OpenAI output binding. @@ -130,7 +123,7 @@ func (p *AzOpenAI) Init(ctx context.Context, meta bindings.Metadata) error { return fmt.Errorf("error getting credentials object: %w", err) } - p.client, err = azopenai.NewClientWithKeyCredential(m.Endpoint, keyCredential, m.DeploymentID, nil) + p.client, err = azopenai.NewClientWithKeyCredential(m.Endpoint, keyCredential, nil) if err != nil { return fmt.Errorf("error creating Azure OpenAI client: %w", err) } @@ -146,11 +139,12 @@ func (p *AzOpenAI) Init(ctx context.Context, meta bindings.Metadata) error { return fmt.Errorf("error getting token credential: %w", innerErr) } - p.client, err = azopenai.NewClient(m.Endpoint, token, m.DeploymentID, nil) + p.client, err = azopenai.NewClient(m.Endpoint, token, nil) if err != nil { return fmt.Errorf("error creating Azure OpenAI client: %w", err) } } + p.deploymentID = m.DeploymentID return nil } @@ -208,10 +202,6 @@ func (p *AzOpenAI) Invoke(ctx context.Context, req *bindings.InvokeRequest) (res return resp, nil } -func (s *ChatSettings) Decode(in any) error { - return config.Decode(in, s) -} - func (p *AzOpenAI) completion(ctx context.Context, message []byte, metadata map[string]string) (response []azopenai.Choice, err error) { prompt := Prompt{ Temperature: 1.0, @@ -230,12 +220,18 @@ func (p *AzOpenAI) completion(ctx context.Context, message []byte, metadata map[ return nil, fmt.Errorf("prompt is required for completion operation") } + if len(prompt.Stop) == 0 { + prompt.Stop = nil + } + resp, err := p.client.GetCompletions(ctx, azopenai.CompletionsOptions{ - Prompt: []*string{&prompt.Prompt}, - MaxTokens: &prompt.MaxTokens, - Temperature: &prompt.Temperature, - TopP: &prompt.TopP, - N: &prompt.N, + DeploymentID: p.deploymentID, + Prompt: []string{prompt.Prompt}, + MaxTokens: &prompt.MaxTokens, + Temperature: &prompt.Temperature, + TopP: &prompt.TopP, + N: &prompt.N, + Stop: prompt.Stop, }, nil) if err != nil { return nil, fmt.Errorf("error getting completion api: %w", err) @@ -249,7 +245,7 @@ func (p *AzOpenAI) completion(ctx context.Context, message []byte, metadata map[ choices := resp.Completions.Choices response = make([]azopenai.Choice, len(choices)) for i, c := range choices { - response[i] = *c + response[i] = c } return response, nil @@ -272,9 +268,13 @@ func (p *AzOpenAI) chatCompletion(ctx context.Context, messageRequest []byte, me return nil, fmt.Errorf("messages are required for chat-completion operation") } - messageReq := make([]*azopenai.ChatMessage, len(messages.Messages)) + if len(messages.Stop) == 0 { + messages.Stop = nil + } + + messageReq := make([]azopenai.ChatMessage, len(messages.Messages)) for i, m := range messages.Messages { - messageReq[i] = &azopenai.ChatMessage{ + messageReq[i] = azopenai.ChatMessage{ Role: to.Ptr(azopenai.ChatRole(m.Role)), Content: to.Ptr(m.Message), } @@ -286,11 +286,13 @@ func (p *AzOpenAI) chatCompletion(ctx context.Context, messageRequest []byte, me } res, err := p.client.GetChatCompletions(ctx, azopenai.ChatCompletionsOptions{ - MaxTokens: maxTokens, - Temperature: &messages.Temperature, - TopP: &messages.TopP, - N: &messages.N, - Messages: messageReq, + DeploymentID: p.deploymentID, + MaxTokens: maxTokens, + Temperature: &messages.Temperature, + TopP: &messages.TopP, + N: &messages.N, + Messages: messageReq, + Stop: messages.Stop, }, nil) if err != nil { return nil, fmt.Errorf("error getting chat completion api: %w", err) @@ -304,7 +306,7 @@ func (p *AzOpenAI) chatCompletion(ctx context.Context, messageRequest []byte, me choices := res.ChatCompletions.Choices response = make([]azopenai.ChatChoice, len(choices)) for i, c := range choices { - response[i] = *c + response[i] = c } return response, nil diff --git a/go.mod b/go.mod index bd2e80e7ab..cb9ec16fbd 100644 --- a/go.mod +++ b/go.mod @@ -10,7 +10,7 @@ require ( dubbo.apache.org/dubbo-go/v3 v3.0.3-0.20230118042253-4f159a2b38f3 github.com/Azure/azure-sdk-for-go/sdk/azcore v1.7.0 github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.3.0 - github.com/Azure/azure-sdk-for-go/sdk/cognitiveservices/azopenai v0.0.0-20230705184009-934612c4f2b5 + github.com/Azure/azure-sdk-for-go/sdk/cognitiveservices/azopenai v0.1.0 github.com/Azure/azure-sdk-for-go/sdk/data/azappconfig v0.5.0 github.com/Azure/azure-sdk-for-go/sdk/data/azcosmos v0.3.5 github.com/Azure/azure-sdk-for-go/sdk/data/aztables v1.0.1 diff --git a/go.sum b/go.sum index 0bb5c3dc12..b8cd36797a 100644 --- a/go.sum +++ b/go.sum @@ -422,8 +422,8 @@ github.com/Azure/azure-sdk-for-go/sdk/azcore v1.7.0 h1:8q4SaHjFsClSvuVne0ID/5Ka8 github.com/Azure/azure-sdk-for-go/sdk/azcore v1.7.0/go.mod h1:bjGvMhVMb+EEm3VRNQawDMUyMMjo+S5ewNjflkep/0Q= github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.3.0 h1:vcYCAze6p19qBW7MhZybIsqD8sMV8js0NyQM8JDnVtg= github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.3.0/go.mod h1:OQeznEEkTZ9OrhHJoDD8ZDq51FHgXjqtP9z6bEwBq9U= -github.com/Azure/azure-sdk-for-go/sdk/cognitiveservices/azopenai v0.0.0-20230705184009-934612c4f2b5 h1:DQCZXtoCPuwBMlAa2aC+B3CfpE6xz2xe1jqdqt8nIJY= -github.com/Azure/azure-sdk-for-go/sdk/cognitiveservices/azopenai v0.0.0-20230705184009-934612c4f2b5/go.mod h1:GQSjs1n073tbMa3e76+STZkyFb+NcEA4N7OB5vNvB3E= +github.com/Azure/azure-sdk-for-go/sdk/cognitiveservices/azopenai v0.1.0 h1:lkflJSWI6jicmEBImjpliUOWCr1PdJO/GcZj3bWx19Q= +github.com/Azure/azure-sdk-for-go/sdk/cognitiveservices/azopenai v0.1.0/go.mod h1:NwVkXm5Ty88Xd7cx6b53fGNeGG3W3ZDXgOXBNHLUy84= github.com/Azure/azure-sdk-for-go/sdk/data/azappconfig v0.5.0 h1:OrKZybbyagpgJiREiIVzH5mV/z9oS4rXqdX7i31DSF0= github.com/Azure/azure-sdk-for-go/sdk/data/azappconfig v0.5.0/go.mod h1:p74+tP95m8830ypJk53L93+BEsjTKY4SKQ75J2NmS5U= github.com/Azure/azure-sdk-for-go/sdk/data/azcosmos v0.3.5 h1:qS0Bp4do0cIvnuQgSGeO6ZCu/q/HlRKl4NPfv1eJ2p0=