Skip to content

Commit

Permalink
Add output binding for OpenAI (#2965)
Browse files Browse the repository at this point in the history
Signed-off-by: Shivam Kumar Singh <[email protected]>
Signed-off-by: Alessandro (Ale) Segala <[email protected]>
Signed-off-by: ItalyPaleAle <[email protected]>
Co-authored-by: Alessandro (Ale) Segala <[email protected]>
  • Loading branch information
shivam-51 and ItalyPaleAle authored Jul 13, 2023
1 parent 1ab15ef commit 95045c4
Show file tree
Hide file tree
Showing 6 changed files with 373 additions and 6 deletions.
38 changes: 38 additions & 0 deletions bindings/azure/openai/metadata.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
# yaml-language-server: $schema=../../../component-metadata-schema.json
schemaVersion: v1
type: bindings
name: azure.openai
version: v1
status: alpha
title: "Azure OpenAI"
urls:
  - title: Reference
    url: https://docs.dapr.io/reference/components-reference/supported-bindings/azure-openai/
binding:
  output: true
  input: false
  operations:
    - name: completion
      description: "Text completion"
    - name: chat-completion
      description: "Chat completion"
builtinAuthenticationProfiles:
  - name: "azuread"
authenticationProfiles:
  - title: "API Key"
    description: "Authenticate using an API key"
    metadata:
      - name: apiKey
        required: true
        sensitive: true
        description: "API Key"
        example: '"1234567890abcdef"'
metadata:
  - name: endpoint
    required: true
    description: "Endpoint of the Azure OpenAI service"
    example: '"https://myopenai.openai.azure.com"'
  - name: deploymentID
    required: true
    description: "ID of the model deployment in the Azure OpenAI service"
    example: '"my-model"'
326 changes: 326 additions & 0 deletions bindings/azure/openai/openai.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,326 @@
/*
Copyright 2023 The Dapr Authors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package openai

import (
"context"
"encoding/json"
"fmt"
"reflect"
"time"

"github.com/Azure/azure-sdk-for-go/sdk/azcore/to"
"github.com/Azure/azure-sdk-for-go/sdk/cognitiveservices/azopenai"

"github.com/dapr/components-contrib/bindings"
azauth "github.com/dapr/components-contrib/internal/authentication/azure"
"github.com/dapr/components-contrib/metadata"
"github.com/dapr/kit/config"
"github.com/dapr/kit/logger"
)

// List of operations.
const (
// CompletionOperation is the operation name that Invoke routes to the text completion handler.
CompletionOperation bindings.OperationKind = "completion"
// ChatCompletionOperation is the operation name that Invoke routes to the chat completion handler.
ChatCompletionOperation bindings.OperationKind = "chat-completion"

// Names of the component metadata properties; they match the mapstructure
// tags on openAIMetadata. Endpoint and DeploymentID are also used in the
// "required metadata not set" errors returned by Init.
APIKey = "apiKey"
DeploymentID = "deploymentID"
Endpoint = "endpoint"
// NOTE(review): the names below look like request option keys, but nothing in
// the visible part of this file references them — confirm they are still needed.
MessagesKey = "messages"
Temperature = "temperature"
MaxTokens = "maxTokens"
TopP = "topP"
N = "n"
Stop = "stop"
FrequencyPenalty = "frequencyPenalty"
LogitBias = "logitBias"
User = "user"
)

// AzOpenAI represents OpenAI output binding.
// It carries the logger handed to NewOpenAI and the Azure OpenAI SDK client
// that Init creates and Close resets.
type AzOpenAI struct {
// logger is the logger supplied to NewOpenAI.
logger logger.Logger
// client is the Azure OpenAI client; it is nil until Init succeeds and is set
// back to nil by Close.
client *azopenai.Client
}

// openAIMetadata holds the component metadata properties that Init decodes
// from the binding configuration.
type openAIMetadata struct {
// APIKey is the API key for the Azure OpenAI API.
// It is optional: when empty, Init falls back to Azure AD authentication.
APIKey string `mapstructure:"apiKey"`
// DeploymentID is the deployment ID for the Azure OpenAI API.
// Init rejects an empty value.
DeploymentID string `mapstructure:"deploymentID"`
// Endpoint is the endpoint for the Azure OpenAI API.
// Init rejects an empty value.
Endpoint string `mapstructure:"endpoint"`
}

// ChatSettings groups generation settings that can be populated from a generic
// input value through its Decode method.
// NOTE(review): none of the operations visible in this file use ChatSettings —
// confirm whether it is consumed elsewhere.
type ChatSettings struct {
// Temperature is decoded from the "temperature" key.
Temperature float32 `mapstructure:"temperature"`
// MaxTokens is decoded from the "maxTokens" key.
MaxTokens int32 `mapstructure:"maxTokens"`
// TopP is decoded from the "topP" key.
TopP float32 `mapstructure:"topP"`
// N is decoded from the "n" key.
N int32 `mapstructure:"n"`
// PresencePenalty is decoded from the "presencePenalty" key.
PresencePenalty float32 `mapstructure:"presencePenalty"`
// FrequencyPenalty is decoded from the "frequencyPenalty" key.
FrequencyPenalty float32 `mapstructure:"frequencyPenalty"`
}

// ChatMessages type for chat completion API.
// It is the JSON shape of the request data accepted by the chat-completion
// operation; chatCompletion pre-fills defaults before unmarshalling into it.
type ChatMessages struct {
// Messages is the conversation to send; chatCompletion rejects an empty list.
Messages []Message `json:"messages"`
// Temperature is the sampling temperature; chatCompletion defaults it to 1.0.
Temperature float32 `json:"temperature"`
// MaxTokens is the token limit; chatCompletion leaves it unset in the API
// request when the value is 0.
MaxTokens int32 `json:"maxTokens"`
// TopP is the top-p sampling value; chatCompletion defaults it to 1.0.
TopP float32 `json:"topP"`
// N is the number of choices requested; chatCompletion defaults it to 1.
N int32 `json:"n"`
// PresencePenalty is read from the "presencePenalty" key; defaults to 0.0.
PresencePenalty float32 `json:"presencePenalty"`
// FrequencyPenalty is read from the "frequencyPenalty" key; defaults to 0.0.
FrequencyPenalty float32 `json:"frequencyPenalty"`
}

// Message type stores the messages for bot conversation.
// The fields carry no JSON tags, so encoding/json matches the keys "Role" and
// "Message" case-insensitively when decoding.
type Message struct {
// Role is converted to an azopenai.ChatRole by chatCompletion.
Role string
// Message is the text sent as the content of the chat message.
Message string
}

// Prompt type for completion API.
// It is the JSON shape of the request data accepted by the completion
// operation; completion pre-fills defaults before unmarshalling into it.
type Prompt struct {
// Prompt is the text to complete; completion rejects an empty value.
Prompt string `json:"prompt"`
// Temperature is the sampling temperature; completion defaults it to 1.0.
Temperature float32 `json:"temperature"`
// MaxTokens is the token limit; completion defaults it to 16.
MaxTokens int32 `json:"maxTokens"`
// TopP is the top-p sampling value; completion defaults it to 1.0.
TopP float32 `json:"topP"`
// N is the number of choices requested; completion defaults it to 1.
N int32 `json:"n"`
// PresencePenalty is read from the "presencePenalty" key; defaults to 0.0.
PresencePenalty float32 `json:"presencePenalty"`
// FrequencyPenalty is read from the "frequencyPenalty" key; defaults to 0.0.
FrequencyPenalty float32 `json:"frequencyPenalty"`
}

// NewOpenAI builds an Azure OpenAI output binding that logs through the given
// logger. The returned binding has no client yet; Init must be called before
// it can serve requests.
func NewOpenAI(logger logger.Logger) bindings.OutputBinding {
	binding := new(AzOpenAI)
	binding.logger = logger
	return binding
}

// Init reads the component metadata and creates the Azure OpenAI client.
//
// The "endpoint" and "deploymentID" properties are mandatory. When "apiKey"
// is present the client authenticates with that key; otherwise the credentials
// are resolved from the Azure AD settings found in the metadata properties.
// Any decoding, validation, credential or client-construction failure is
// returned as an error.
func (p *AzOpenAI) Init(ctx context.Context, meta bindings.Metadata) error {
	var m openAIMetadata
	err := metadata.DecodeMetadata(meta.Properties, &m)
	if err != nil {
		return fmt.Errorf("error decoding metadata: %w", err)
	}

	// Validate the mandatory properties, endpoint first.
	switch {
	case m.Endpoint == "":
		return fmt.Errorf("required metadata not set: %s", Endpoint)
	case m.DeploymentID == "":
		return fmt.Errorf("required metadata not set: %s", DeploymentID)
	}

	if m.APIKey == "" {
		// No API key configured: authenticate with Azure AD.
		settings, innerErr := azauth.NewEnvironmentSettings(meta.Properties)
		if innerErr != nil {
			return fmt.Errorf("error creating environment settings: %w", innerErr)
		}

		token, innerErr := settings.GetTokenCredential()
		if innerErr != nil {
			return fmt.Errorf("error getting token credential: %w", innerErr)
		}

		p.client, err = azopenai.NewClient(m.Endpoint, token, m.DeploymentID, nil)
		if err != nil {
			return fmt.Errorf("error creating Azure OpenAI client: %w", err)
		}

		return nil
	}

	// API key authentication.
	var keyCredential azopenai.KeyCredential
	keyCredential, err = azopenai.NewKeyCredential(m.APIKey)
	if err != nil {
		return fmt.Errorf("error getting credentials object: %w", err)
	}

	p.client, err = azopenai.NewClientWithKeyCredential(m.Endpoint, keyCredential, m.DeploymentID, nil)
	if err != nil {
		return fmt.Errorf("error creating Azure OpenAI client: %w", err)
	}

	return nil
}

// Operations lists the operation kinds this binding can be invoked with:
// chat completion first, then text completion.
func (p *AzOpenAI) Operations() []bindings.OperationKind {
	supported := make([]bindings.OperationKind, 0, 2)
	supported = append(supported, ChatCompletionOperation, CompletionOperation)
	return supported
}

// Invoke handles all invoke operations.
//
// It routes req.Operation to the completion or chat-completion handler,
// JSON-encodes the returned choices into the response data, and records the
// operation name, the start and end time (UTC, RFC 3339 with nanoseconds) and
// the elapsed duration in the response metadata.
//
// An error is returned when the request is nil or has no metadata, when the
// binding has no client (Init was not called or Close already ran), when the
// operation is not supported, when the handler fails, or when the handler
// result cannot be serialized to JSON.
func (p *AzOpenAI) Invoke(ctx context.Context, req *bindings.InvokeRequest) (resp *bindings.InvokeResponse, err error) {
	if req == nil || len(req.Metadata) == 0 {
		return nil, fmt.Errorf("invalid request: metadata is required")
	}

	// Fail with an error rather than a nil-pointer panic inside the SDK.
	if p.client == nil {
		return nil, fmt.Errorf("invalid state: Azure OpenAI client is not initialized")
	}

	startTime := time.Now().UTC()

	// result holds the slice of choices produced by the selected handler.
	var result any
	switch req.Operation { //nolint:exhaustive
	case CompletionOperation:
		result, err = p.completion(ctx, req.Data, req.Metadata)
		if err != nil {
			return nil, fmt.Errorf("error performing completion: %w", err)
		}

	case ChatCompletionOperation:
		result, err = p.chatCompletion(ctx, req.Data, req.Metadata)
		if err != nil {
			return nil, fmt.Errorf("error performing chat completion: %w", err)
		}

	default:
		return nil, fmt.Errorf(
			"invalid operation type: %s. Expected %s, %s",
			req.Operation, CompletionOperation, ChatCompletionOperation,
		)
	}

	// Surface serialization failures instead of silently returning empty data.
	data, err := json.Marshal(result)
	if err != nil {
		return nil, fmt.Errorf("error marshalling the response: %w", err)
	}

	endTime := time.Now().UTC()
	resp = &bindings.InvokeResponse{
		Data: data,
		Metadata: map[string]string{
			"operation":  string(req.Operation),
			"start-time": startTime.Format(time.RFC3339Nano),
			"end-time":   endTime.Format(time.RFC3339Nano),
			"duration":   endTime.Sub(startTime).String(),
		},
	}

	return resp, nil
}

// Decode populates the receiver from in using config.Decode and returns any
// error reported by that call.
func (s *ChatSettings) Decode(in any) error {
	err := config.Decode(in, s)
	return err
}

// completion runs a text completion request against the Azure OpenAI client.
//
// message is the JSON-encoded Prompt. Fields missing from the payload keep
// these defaults: temperature 1.0, topP 1.0, maxTokens 16, n 1, and both
// penalties 0.0. The metadata argument is accepted for symmetry with Invoke
// but is not read.
//
// It returns the choices produced by the service as a non-nil slice (empty
// when the service returned none). An error is returned when the payload is
// not valid JSON, when the prompt is empty, or when the API call fails.
func (p *AzOpenAI) completion(ctx context.Context, message []byte, metadata map[string]string) (response []azopenai.Choice, err error) {
	prompt := Prompt{
		Temperature:      1.0,
		TopP:             1.0,
		MaxTokens:        16,
		N:                1,
		PresencePenalty:  0.0,
		FrequencyPenalty: 0.0,
	}
	if err = json.Unmarshal(message, &prompt); err != nil {
		return nil, fmt.Errorf("error unmarshalling the input object: %w", err)
	}

	if prompt.Prompt == "" {
		return nil, fmt.Errorf("prompt is required for completion operation")
	}

	resp, err := p.client.GetCompletions(ctx, azopenai.CompletionsOptions{
		Prompt:      []*string{&prompt.Prompt},
		MaxTokens:   &prompt.MaxTokens,
		Temperature: &prompt.Temperature,
		TopP:        &prompt.TopP,
		N:           &prompt.N,
		// Forward the penalties so the values from the payload take effect.
		PresencePenalty:  &prompt.PresencePenalty,
		FrequencyPenalty: &prompt.FrequencyPenalty,
	}, nil)
	if err != nil {
		return nil, fmt.Errorf("error getting completion api: %w", err)
	}

	// Copy the choices by value. The slice is always non-nil so that it
	// serializes as a JSON array; nil entries are skipped rather than dereferenced.
	choices := resp.Completions.Choices
	response = make([]azopenai.Choice, 0, len(choices))
	for _, c := range choices {
		if c == nil {
			continue
		}
		response = append(response, *c)
	}

	return response, nil
}

// chatCompletion runs a chat completion request against the Azure OpenAI client.
//
// messageRequest is the JSON-encoded ChatMessages. Fields missing from the
// payload keep these defaults: temperature 1.0, topP 1.0, n 1, and both
// penalties 0.0. maxTokens is only sent to the service when it is non-zero.
// The metadata argument is accepted for symmetry with Invoke but is not read.
//
// It returns the choices produced by the service as a non-nil slice (empty
// when the service returned none). An error is returned when the payload is
// not valid JSON, when no messages are supplied, or when the API call fails.
func (p *AzOpenAI) chatCompletion(ctx context.Context, messageRequest []byte, metadata map[string]string) (response []azopenai.ChatChoice, err error) {
	messages := ChatMessages{
		Temperature:      1.0,
		TopP:             1.0,
		N:                1,
		PresencePenalty:  0.0,
		FrequencyPenalty: 0.0,
	}
	if err = json.Unmarshal(messageRequest, &messages); err != nil {
		return nil, fmt.Errorf("error unmarshalling the input object: %w", err)
	}

	if len(messages.Messages) == 0 {
		return nil, fmt.Errorf("messages are required for chat-completion operation")
	}

	// Translate the binding's message type into the SDK's message type.
	messageReq := make([]*azopenai.ChatMessage, len(messages.Messages))
	for i, m := range messages.Messages {
		messageReq[i] = &azopenai.ChatMessage{
			Role:    to.Ptr(azopenai.ChatRole(m.Role)),
			Content: to.Ptr(m.Message),
		}
	}

	// A zero maxTokens means "not specified": leave the option unset.
	var maxTokens *int32
	if messages.MaxTokens != 0 {
		maxTokens = &messages.MaxTokens
	}

	res, err := p.client.GetChatCompletions(ctx, azopenai.ChatCompletionsOptions{
		MaxTokens:   maxTokens,
		Temperature: &messages.Temperature,
		TopP:        &messages.TopP,
		N:           &messages.N,
		Messages:    messageReq,
		// Forward the penalties so the values from the payload take effect.
		PresencePenalty:  &messages.PresencePenalty,
		FrequencyPenalty: &messages.FrequencyPenalty,
	}, nil)
	if err != nil {
		return nil, fmt.Errorf("error getting chat completion api: %w", err)
	}

	// Copy the choices by value. The slice is always non-nil so that it
	// serializes as a JSON array; nil entries are skipped rather than dereferenced.
	choices := res.ChatCompletions.Choices
	response = make([]azopenai.ChatChoice, 0, len(choices))
	for _, c := range choices {
		if c == nil {
			continue
		}
		response = append(response, *c)
	}

	return response, nil
}

// Close releases the binding's reference to the Azure OpenAI client.
// It always returns nil.
func (p *AzOpenAI) Close() error {
	p.client = nil
	return nil
}

// GetComponentMetadata reports the metadata properties understood by this
// binding, derived by reflection from the openAIMetadata struct.
func (p *AzOpenAI) GetComponentMetadata() map[string]string {
	info := make(map[string]string)
	structType := reflect.TypeOf(openAIMetadata{})
	metadata.GetMetadataInfoFromStructType(structType, &info, metadata.BindingType)
	return info
}
3 changes: 2 additions & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,9 @@ require (
cloud.google.com/go/secretmanager v1.10.0
cloud.google.com/go/storage v1.30.1
dubbo.apache.org/dubbo-go/v3 v3.0.3-0.20230118042253-4f159a2b38f3
github.com/Azure/azure-sdk-for-go/sdk/azcore v1.6.0
github.com/Azure/azure-sdk-for-go/sdk/azcore v1.7.0
github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.3.0
github.com/Azure/azure-sdk-for-go/sdk/cognitiveservices/azopenai v0.0.0-20230705184009-934612c4f2b5
github.com/Azure/azure-sdk-for-go/sdk/data/azappconfig v0.5.0
github.com/Azure/azure-sdk-for-go/sdk/data/azcosmos v0.3.5
github.com/Azure/azure-sdk-for-go/sdk/data/aztables v1.0.1
Expand Down
6 changes: 4 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -420,11 +420,13 @@ github.com/Azure/azure-sdk-for-go v68.0.0+incompatible h1:fcYLmCpyNYRnvJbPerq7U0
github.com/Azure/azure-sdk-for-go v68.0.0+incompatible/go.mod h1:9XXNKU+eRnpl9moKnB4QOLf1HestfXbmab5FXxiDBjc=
github.com/Azure/azure-sdk-for-go/sdk/azcore v1.0.0/go.mod h1:uGG2W01BaETf0Ozp+QxxKJdMBNRWPdstHG0Fmdwn1/U=
github.com/Azure/azure-sdk-for-go/sdk/azcore v1.1.2/go.mod h1:uGG2W01BaETf0Ozp+QxxKJdMBNRWPdstHG0Fmdwn1/U=
github.com/Azure/azure-sdk-for-go/sdk/azcore v1.6.0 h1:8kDqDngH+DmVBiCtIjCFTGa7MBnsIOkF9IccInFEbjk=
github.com/Azure/azure-sdk-for-go/sdk/azcore v1.6.0/go.mod h1:bjGvMhVMb+EEm3VRNQawDMUyMMjo+S5ewNjflkep/0Q=
github.com/Azure/azure-sdk-for-go/sdk/azcore v1.7.0 h1:8q4SaHjFsClSvuVne0ID/5Ka8u3fcIHyqkLjcFpNRHQ=
github.com/Azure/azure-sdk-for-go/sdk/azcore v1.7.0/go.mod h1:bjGvMhVMb+EEm3VRNQawDMUyMMjo+S5ewNjflkep/0Q=
github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.2.1/go.mod h1:gLa1CL2RNE4s7M3yopJ/p0iq5DdY6Yv5ZUt9MTRZOQM=
github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.3.0 h1:vcYCAze6p19qBW7MhZybIsqD8sMV8js0NyQM8JDnVtg=
github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.3.0/go.mod h1:OQeznEEkTZ9OrhHJoDD8ZDq51FHgXjqtP9z6bEwBq9U=
github.com/Azure/azure-sdk-for-go/sdk/cognitiveservices/azopenai v0.0.0-20230705184009-934612c4f2b5 h1:DQCZXtoCPuwBMlAa2aC+B3CfpE6xz2xe1jqdqt8nIJY=
github.com/Azure/azure-sdk-for-go/sdk/cognitiveservices/azopenai v0.0.0-20230705184009-934612c4f2b5/go.mod h1:GQSjs1n073tbMa3e76+STZkyFb+NcEA4N7OB5vNvB3E=
github.com/Azure/azure-sdk-for-go/sdk/data/azappconfig v0.5.0 h1:OrKZybbyagpgJiREiIVzH5mV/z9oS4rXqdX7i31DSF0=
github.com/Azure/azure-sdk-for-go/sdk/data/azappconfig v0.5.0/go.mod h1:p74+tP95m8830ypJk53L93+BEsjTKY4SKQ75J2NmS5U=
github.com/Azure/azure-sdk-for-go/sdk/data/azcosmos v0.3.5 h1:qS0Bp4do0cIvnuQgSGeO6ZCu/q/HlRKl4NPfv1eJ2p0=
Expand Down
Loading

0 comments on commit 95045c4

Please sign in to comment.