Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix lint failure for openai binding #3000

Merged
merged 7 commits into from
Jul 21, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 41 additions & 39 deletions bindings/azure/openai/openai.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@ import (
"github.com/dapr/components-contrib/bindings"
azauth "github.com/dapr/components-contrib/internal/authentication/azure"
"github.com/dapr/components-contrib/metadata"
"github.com/dapr/kit/config"
"github.com/dapr/kit/logger"
)

Expand All @@ -51,8 +50,9 @@ const (

// AzOpenAI represents OpenAI output binding.
type AzOpenAI struct {
logger logger.Logger
client *azopenai.Client
logger logger.Logger
client *azopenai.Client
deploymentID string
}

type openAIMetadata struct {
Expand All @@ -64,15 +64,6 @@ type openAIMetadata struct {
Endpoint string `mapstructure:"endpoint"`
}

type ChatSettings struct {
Temperature float32 `mapstructure:"temperature"`
MaxTokens int32 `mapstructure:"maxTokens"`
TopP float32 `mapstructure:"topP"`
N int32 `mapstructure:"n"`
PresencePenalty float32 `mapstructure:"presencePenalty"`
FrequencyPenalty float32 `mapstructure:"frequencyPenalty"`
}

// ChatMessages type for chat completion API.
type ChatMessages struct {
Messages []Message `json:"messages"`
Expand All @@ -82,6 +73,7 @@ type ChatMessages struct {
N int32 `json:"n"`
PresencePenalty float32 `json:"presencePenalty"`
FrequencyPenalty float32 `json:"frequencyPenalty"`
Stop []string `json:"stop"`
}

// Message type stores the messages for bot conversation.
Expand All @@ -92,13 +84,14 @@ type Message struct {

// Prompt type for completion API.
type Prompt struct {
Prompt string `json:"prompt"`
Temperature float32 `json:"temperature"`
MaxTokens int32 `json:"maxTokens"`
TopP float32 `json:"topP"`
N int32 `json:"n"`
PresencePenalty float32 `json:"presencePenalty"`
FrequencyPenalty float32 `json:"frequencyPenalty"`
Prompt string `json:"prompt"`
Temperature float32 `json:"temperature"`
MaxTokens int32 `json:"maxTokens"`
TopP float32 `json:"topP"`
N int32 `json:"n"`
PresencePenalty float32 `json:"presencePenalty"`
FrequencyPenalty float32 `json:"frequencyPenalty"`
Stop []string `json:"stop"`
}

// NewOpenAI returns a new OpenAI output binding.
Expand Down Expand Up @@ -130,7 +123,7 @@ func (p *AzOpenAI) Init(ctx context.Context, meta bindings.Metadata) error {
return fmt.Errorf("error getting credentials object: %w", err)
}

p.client, err = azopenai.NewClientWithKeyCredential(m.Endpoint, keyCredential, m.DeploymentID, nil)
p.client, err = azopenai.NewClientWithKeyCredential(m.Endpoint, keyCredential, nil)
if err != nil {
return fmt.Errorf("error creating Azure OpenAI client: %w", err)
}
Expand All @@ -146,11 +139,12 @@ func (p *AzOpenAI) Init(ctx context.Context, meta bindings.Metadata) error {
return fmt.Errorf("error getting token credential: %w", innerErr)
}

p.client, err = azopenai.NewClient(m.Endpoint, token, m.DeploymentID, nil)
p.client, err = azopenai.NewClient(m.Endpoint, token, nil)
if err != nil {
return fmt.Errorf("error creating Azure OpenAI client: %w", err)
}
}
p.deploymentID = m.DeploymentID

return nil
}
Expand Down Expand Up @@ -208,10 +202,6 @@ func (p *AzOpenAI) Invoke(ctx context.Context, req *bindings.InvokeRequest) (res
return resp, nil
}

func (s *ChatSettings) Decode(in any) error {
return config.Decode(in, s)
}

func (p *AzOpenAI) completion(ctx context.Context, message []byte, metadata map[string]string) (response []azopenai.Choice, err error) {
prompt := Prompt{
Temperature: 1.0,
Expand All @@ -230,12 +220,18 @@ func (p *AzOpenAI) completion(ctx context.Context, message []byte, metadata map[
return nil, fmt.Errorf("prompt is required for completion operation")
}

if len(prompt.Stop) == 0 {
prompt.Stop = nil
}

resp, err := p.client.GetCompletions(ctx, azopenai.CompletionsOptions{
Prompt: []*string{&prompt.Prompt},
MaxTokens: &prompt.MaxTokens,
Temperature: &prompt.Temperature,
TopP: &prompt.TopP,
N: &prompt.N,
DeploymentID: p.deploymentID,
Prompt: []string{prompt.Prompt},
MaxTokens: &prompt.MaxTokens,
Temperature: &prompt.Temperature,
TopP: &prompt.TopP,
N: &prompt.N,
Stop: prompt.Stop,
}, nil)
if err != nil {
return nil, fmt.Errorf("error getting completion api: %w", err)
Expand All @@ -249,7 +245,7 @@ func (p *AzOpenAI) completion(ctx context.Context, message []byte, metadata map[
choices := resp.Completions.Choices
response = make([]azopenai.Choice, len(choices))
for i, c := range choices {
response[i] = *c
response[i] = c
}

return response, nil
Expand All @@ -272,9 +268,13 @@ func (p *AzOpenAI) chatCompletion(ctx context.Context, messageRequest []byte, me
return nil, fmt.Errorf("messages are required for chat-completion operation")
}

messageReq := make([]*azopenai.ChatMessage, len(messages.Messages))
if len(messages.Stop) == 0 {
messages.Stop = nil
}

messageReq := make([]azopenai.ChatMessage, len(messages.Messages))
for i, m := range messages.Messages {
messageReq[i] = &azopenai.ChatMessage{
messageReq[i] = azopenai.ChatMessage{
Role: to.Ptr(azopenai.ChatRole(m.Role)),
Content: to.Ptr(m.Message),
}
Expand All @@ -286,11 +286,13 @@ func (p *AzOpenAI) chatCompletion(ctx context.Context, messageRequest []byte, me
}

res, err := p.client.GetChatCompletions(ctx, azopenai.ChatCompletionsOptions{
MaxTokens: maxTokens,
Temperature: &messages.Temperature,
TopP: &messages.TopP,
N: &messages.N,
Messages: messageReq,
DeploymentID: p.deploymentID,
MaxTokens: maxTokens,
Temperature: &messages.Temperature,
TopP: &messages.TopP,
N: &messages.N,
Messages: messageReq,
Stop: messages.Stop,
}, nil)
if err != nil {
return nil, fmt.Errorf("error getting chat completion api: %w", err)
Expand All @@ -304,7 +306,7 @@ func (p *AzOpenAI) chatCompletion(ctx context.Context, messageRequest []byte, me
choices := res.ChatCompletions.Choices
response = make([]azopenai.ChatChoice, len(choices))
for i, c := range choices {
response[i] = *c
response[i] = c
}

return response, nil
Expand Down
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ require (
dubbo.apache.org/dubbo-go/v3 v3.0.3-0.20230118042253-4f159a2b38f3
github.com/Azure/azure-sdk-for-go/sdk/azcore v1.7.0
github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.3.0
github.com/Azure/azure-sdk-for-go/sdk/cognitiveservices/azopenai v0.0.0-20230705184009-934612c4f2b5
github.com/Azure/azure-sdk-for-go/sdk/cognitiveservices/azopenai v0.1.0
github.com/Azure/azure-sdk-for-go/sdk/data/azappconfig v0.5.0
github.com/Azure/azure-sdk-for-go/sdk/data/azcosmos v0.3.5
github.com/Azure/azure-sdk-for-go/sdk/data/aztables v1.0.1
Expand Down
4 changes: 2 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -422,8 +422,8 @@ github.com/Azure/azure-sdk-for-go/sdk/azcore v1.7.0 h1:8q4SaHjFsClSvuVne0ID/5Ka8
github.com/Azure/azure-sdk-for-go/sdk/azcore v1.7.0/go.mod h1:bjGvMhVMb+EEm3VRNQawDMUyMMjo+S5ewNjflkep/0Q=
github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.3.0 h1:vcYCAze6p19qBW7MhZybIsqD8sMV8js0NyQM8JDnVtg=
github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.3.0/go.mod h1:OQeznEEkTZ9OrhHJoDD8ZDq51FHgXjqtP9z6bEwBq9U=
github.com/Azure/azure-sdk-for-go/sdk/cognitiveservices/azopenai v0.0.0-20230705184009-934612c4f2b5 h1:DQCZXtoCPuwBMlAa2aC+B3CfpE6xz2xe1jqdqt8nIJY=
github.com/Azure/azure-sdk-for-go/sdk/cognitiveservices/azopenai v0.0.0-20230705184009-934612c4f2b5/go.mod h1:GQSjs1n073tbMa3e76+STZkyFb+NcEA4N7OB5vNvB3E=
github.com/Azure/azure-sdk-for-go/sdk/cognitiveservices/azopenai v0.1.0 h1:lkflJSWI6jicmEBImjpliUOWCr1PdJO/GcZj3bWx19Q=
github.com/Azure/azure-sdk-for-go/sdk/cognitiveservices/azopenai v0.1.0/go.mod h1:NwVkXm5Ty88Xd7cx6b53fGNeGG3W3ZDXgOXBNHLUy84=
github.com/Azure/azure-sdk-for-go/sdk/data/azappconfig v0.5.0 h1:OrKZybbyagpgJiREiIVzH5mV/z9oS4rXqdX7i31DSF0=
github.com/Azure/azure-sdk-for-go/sdk/data/azappconfig v0.5.0/go.mod h1:p74+tP95m8830ypJk53L93+BEsjTKY4SKQ75J2NmS5U=
github.com/Azure/azure-sdk-for-go/sdk/data/azcosmos v0.3.5 h1:qS0Bp4do0cIvnuQgSGeO6ZCu/q/HlRKl4NPfv1eJ2p0=
Expand Down
Loading