-
Notifications
You must be signed in to change notification settings - Fork 478
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add output binding for OpenAI (#2965)
Signed-off-by: Shivam Kumar Singh <[email protected]> Signed-off-by: Alessandro (Ale) Segala <[email protected]> Signed-off-by: ItalyPaleAle <[email protected]> Co-authored-by: Alessandro (Ale) Segala <[email protected]>
- Loading branch information
1 parent
1ab15ef
commit 95045c4
Showing
6 changed files
with
373 additions
and
6 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
# yaml-language-server: $schema=../../../component-metadata-schema.json | ||
schemaVersion: v1 | ||
type: bindings | ||
name: azure.openai | ||
version: v1 | ||
status: alpha | ||
title: "Azure OpenAI" | ||
urls: | ||
- title: Reference | ||
url: https://docs.dapr.io/reference/components-reference/supported-bindings/azure-openai/ | ||
binding: | ||
output: true | ||
input: false | ||
operations: | ||
- name: completion | ||
description: "Text completion" | ||
- name: chat-completion | ||
description: "Chat completion" | ||
builtinAuthenticationProfiles: | ||
- name: "azuread" | ||
authenticationProfiles: | ||
- title: "API Key" | ||
description: "Authenticate using an API key" | ||
metadata: | ||
- name: apiKey | ||
required: true | ||
sensitive: true | ||
description: "API Key" | ||
example: '"1234567890abcdef"' | ||
metadata: | ||
- name: endpoint | ||
required: true | ||
description: "Endpoint of the Azure OpenAI service" | ||
example: '"https://myopenai.openai.azure.com"' | ||
- name: deploymentID | ||
required: true | ||
description: "ID of the model deployment in the Azure OpenAI service" | ||
example: '"my-model"' |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,326 @@ | ||
/* | ||
Copyright 2023 The Dapr Authors | ||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
http://www.apache.org/licenses/LICENSE-2.0 | ||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. | ||
*/ | ||
|
||
package openai | ||
|
||
import ( | ||
"context" | ||
"encoding/json" | ||
"fmt" | ||
"reflect" | ||
"time" | ||
|
||
"github.com/Azure/azure-sdk-for-go/sdk/azcore/to" | ||
"github.com/Azure/azure-sdk-for-go/sdk/cognitiveservices/azopenai" | ||
|
||
"github.com/dapr/components-contrib/bindings" | ||
azauth "github.com/dapr/components-contrib/internal/authentication/azure" | ||
"github.com/dapr/components-contrib/metadata" | ||
"github.com/dapr/kit/config" | ||
"github.com/dapr/kit/logger" | ||
) | ||
|
||
// List of operations. | ||
const ( | ||
CompletionOperation bindings.OperationKind = "completion" | ||
ChatCompletionOperation bindings.OperationKind = "chat-completion" | ||
|
||
APIKey = "apiKey" | ||
DeploymentID = "deploymentID" | ||
Endpoint = "endpoint" | ||
MessagesKey = "messages" | ||
Temperature = "temperature" | ||
MaxTokens = "maxTokens" | ||
TopP = "topP" | ||
N = "n" | ||
Stop = "stop" | ||
FrequencyPenalty = "frequencyPenalty" | ||
LogitBias = "logitBias" | ||
User = "user" | ||
) | ||
|
||
// AzOpenAI represents OpenAI output binding. | ||
type AzOpenAI struct { | ||
logger logger.Logger | ||
client *azopenai.Client | ||
} | ||
|
||
// openAIMetadata holds the component metadata decoded by Init.
type openAIMetadata struct {
	// APIKey, when non-empty, selects API key authentication; when empty,
	// Init falls back to Azure AD authentication.
	APIKey string `mapstructure:"apiKey"`
	// DeploymentID identifies the model deployment; Init rejects an empty value.
	DeploymentID string `mapstructure:"deploymentID"`
	// Endpoint is the Azure OpenAI service endpoint; Init rejects an empty value.
	Endpoint string `mapstructure:"endpoint"`
}
|
||
// ChatSettings groups the tunable generation parameters. Its fields carry
// mapstructure tags and it can be filled through its Decode method.
type ChatSettings struct {
	Temperature      float32 `mapstructure:"temperature"`
	MaxTokens        int32   `mapstructure:"maxTokens"`
	TopP             float32 `mapstructure:"topP"`
	N                int32   `mapstructure:"n"`
	PresencePenalty  float32 `mapstructure:"presencePenalty"`
	FrequencyPenalty float32 `mapstructure:"frequencyPenalty"`
}
|
||
// ChatMessages type for chat completion API. | ||
type ChatMessages struct { | ||
Messages []Message `json:"messages"` | ||
Temperature float32 `json:"temperature"` | ||
MaxTokens int32 `json:"maxTokens"` | ||
TopP float32 `json:"topP"` | ||
N int32 `json:"n"` | ||
PresencePenalty float32 `json:"presencePenalty"` | ||
FrequencyPenalty float32 `json:"frequencyPenalty"` | ||
} | ||
|
||
// Message is a single entry of a chat conversation.
type Message struct {
	// Role is the author role; it is converted to azopenai.ChatRole when the
	// chat request is built.
	Role string
	// Message is the text content of the entry.
	Message string
}
|
||
// Prompt is the JSON request body of the completion operation: the prompt
// text plus optional generation parameters.
type Prompt struct {
	// Prompt is the text to complete; the operation fails when it is empty.
	Prompt           string  `json:"prompt"`
	Temperature      float32 `json:"temperature"`
	MaxTokens        int32   `json:"maxTokens"`
	TopP             float32 `json:"topP"`
	N                int32   `json:"n"`
	PresencePenalty  float32 `json:"presencePenalty"`
	FrequencyPenalty float32 `json:"frequencyPenalty"`
}
|
||
// NewOpenAI returns a new OpenAI output binding. | ||
func NewOpenAI(logger logger.Logger) bindings.OutputBinding { | ||
return &AzOpenAI{ | ||
logger: logger, | ||
} | ||
} | ||
|
||
// Init initializes the OpenAI binding. | ||
func (p *AzOpenAI) Init(ctx context.Context, meta bindings.Metadata) error { | ||
m := openAIMetadata{} | ||
err := metadata.DecodeMetadata(meta.Properties, &m) | ||
if err != nil { | ||
return fmt.Errorf("error decoding metadata: %w", err) | ||
} | ||
if m.Endpoint == "" { | ||
return fmt.Errorf("required metadata not set: %s", Endpoint) | ||
} | ||
if m.DeploymentID == "" { | ||
return fmt.Errorf("required metadata not set: %s", DeploymentID) | ||
} | ||
|
||
if m.APIKey != "" { | ||
// use API key authentication | ||
var keyCredential azopenai.KeyCredential | ||
keyCredential, err = azopenai.NewKeyCredential(m.APIKey) | ||
if err != nil { | ||
return fmt.Errorf("error getting credentials object: %w", err) | ||
} | ||
|
||
p.client, err = azopenai.NewClientWithKeyCredential(m.Endpoint, keyCredential, m.DeploymentID, nil) | ||
if err != nil { | ||
return fmt.Errorf("error creating Azure OpenAI client: %w", err) | ||
} | ||
} else { | ||
// fallback to Azure AD authentication | ||
settings, innerErr := azauth.NewEnvironmentSettings(meta.Properties) | ||
if innerErr != nil { | ||
return fmt.Errorf("error creating environment settings: %w", innerErr) | ||
} | ||
|
||
token, innerErr := settings.GetTokenCredential() | ||
if innerErr != nil { | ||
return fmt.Errorf("error getting token credential: %w", innerErr) | ||
} | ||
|
||
p.client, err = azopenai.NewClient(m.Endpoint, token, m.DeploymentID, nil) | ||
if err != nil { | ||
return fmt.Errorf("error creating Azure OpenAI client: %w", err) | ||
} | ||
} | ||
|
||
return nil | ||
} | ||
|
||
// Operations returns list of operations supported by OpenAI binding. | ||
func (p *AzOpenAI) Operations() []bindings.OperationKind { | ||
return []bindings.OperationKind{ | ||
ChatCompletionOperation, | ||
CompletionOperation, | ||
} | ||
} | ||
|
||
// Invoke handles all invoke operations. | ||
func (p *AzOpenAI) Invoke(ctx context.Context, req *bindings.InvokeRequest) (resp *bindings.InvokeResponse, err error) { | ||
if req == nil || len(req.Metadata) == 0 { | ||
return nil, fmt.Errorf("invalid request: metadata is required") | ||
} | ||
|
||
startTime := time.Now().UTC() | ||
resp = &bindings.InvokeResponse{ | ||
Metadata: map[string]string{ | ||
"operation": string(req.Operation), | ||
"start-time": startTime.Format(time.RFC3339Nano), | ||
}, | ||
} | ||
|
||
switch req.Operation { //nolint:exhaustive | ||
case CompletionOperation: | ||
response, err := p.completion(ctx, req.Data, req.Metadata) | ||
if err != nil { | ||
return nil, fmt.Errorf("error performing completion: %w", err) | ||
} | ||
responseAsBytes, _ := json.Marshal(response) | ||
resp.Data = responseAsBytes | ||
|
||
case ChatCompletionOperation: | ||
response, err := p.chatCompletion(ctx, req.Data, req.Metadata) | ||
if err != nil { | ||
return nil, fmt.Errorf("error performing chat completion: %w", err) | ||
} | ||
responseAsBytes, _ := json.Marshal(response) | ||
resp.Data = responseAsBytes | ||
|
||
default: | ||
return nil, fmt.Errorf( | ||
"invalid operation type: %s. Expected %s, %s", | ||
req.Operation, CompletionOperation, ChatCompletionOperation, | ||
) | ||
} | ||
|
||
endTime := time.Now().UTC() | ||
resp.Metadata["end-time"] = endTime.Format(time.RFC3339Nano) | ||
resp.Metadata["duration"] = endTime.Sub(startTime).String() | ||
|
||
return resp, nil | ||
} | ||
|
||
func (s *ChatSettings) Decode(in any) error { | ||
return config.Decode(in, s) | ||
} | ||
|
||
func (p *AzOpenAI) completion(ctx context.Context, message []byte, metadata map[string]string) (response []azopenai.Choice, err error) { | ||
prompt := Prompt{ | ||
Temperature: 1.0, | ||
TopP: 1.0, | ||
MaxTokens: 16, | ||
N: 1, | ||
PresencePenalty: 0.0, | ||
FrequencyPenalty: 0.0, | ||
} | ||
err = json.Unmarshal(message, &prompt) | ||
if err != nil { | ||
return nil, fmt.Errorf("error unmarshalling the input object: %w", err) | ||
} | ||
|
||
if prompt.Prompt == "" { | ||
return nil, fmt.Errorf("prompt is required for completion operation") | ||
} | ||
|
||
resp, err := p.client.GetCompletions(ctx, azopenai.CompletionsOptions{ | ||
Prompt: []*string{&prompt.Prompt}, | ||
MaxTokens: &prompt.MaxTokens, | ||
Temperature: &prompt.Temperature, | ||
TopP: &prompt.TopP, | ||
N: &prompt.N, | ||
}, nil) | ||
if err != nil { | ||
return nil, fmt.Errorf("error getting completion api: %w", err) | ||
} | ||
|
||
// No choices returned | ||
if len(resp.Completions.Choices) == 0 { | ||
return []azopenai.Choice{}, nil | ||
} | ||
|
||
choices := resp.Completions.Choices | ||
response = make([]azopenai.Choice, len(choices)) | ||
for i, c := range choices { | ||
response[i] = *c | ||
} | ||
|
||
return response, nil | ||
} | ||
|
||
func (p *AzOpenAI) chatCompletion(ctx context.Context, messageRequest []byte, metadata map[string]string) (response []azopenai.ChatChoice, err error) { | ||
messages := ChatMessages{ | ||
Temperature: 1.0, | ||
TopP: 1.0, | ||
N: 1, | ||
PresencePenalty: 0.0, | ||
FrequencyPenalty: 0.0, | ||
} | ||
err = json.Unmarshal(messageRequest, &messages) | ||
if err != nil { | ||
return nil, fmt.Errorf("error unmarshalling the input object: %w", err) | ||
} | ||
|
||
if len(messages.Messages) == 0 { | ||
return nil, fmt.Errorf("messages are required for chat-completion operation") | ||
} | ||
|
||
messageReq := make([]*azopenai.ChatMessage, len(messages.Messages)) | ||
for i, m := range messages.Messages { | ||
messageReq[i] = &azopenai.ChatMessage{ | ||
Role: to.Ptr(azopenai.ChatRole(m.Role)), | ||
Content: to.Ptr(m.Message), | ||
} | ||
} | ||
|
||
var maxTokens *int32 | ||
if messages.MaxTokens != 0 { | ||
maxTokens = &messages.MaxTokens | ||
} | ||
|
||
res, err := p.client.GetChatCompletions(ctx, azopenai.ChatCompletionsOptions{ | ||
MaxTokens: maxTokens, | ||
Temperature: &messages.Temperature, | ||
TopP: &messages.TopP, | ||
N: &messages.N, | ||
Messages: messageReq, | ||
}, nil) | ||
if err != nil { | ||
return nil, fmt.Errorf("error getting chat completion api: %w", err) | ||
} | ||
|
||
// No choices returned. | ||
if len(res.ChatCompletions.Choices) == 0 { | ||
return []azopenai.ChatChoice{}, nil | ||
} | ||
|
||
choices := res.ChatCompletions.Choices | ||
response = make([]azopenai.ChatChoice, len(choices)) | ||
for i, c := range choices { | ||
response[i] = *c | ||
} | ||
|
||
return response, nil | ||
} | ||
|
||
// Close Az OpenAI instance. | ||
func (p *AzOpenAI) Close() error { | ||
p.client = nil | ||
|
||
return nil | ||
} | ||
|
||
// GetComponentMetadata returns the metadata of the component. | ||
func (p *AzOpenAI) GetComponentMetadata() map[string]string { | ||
metadataStruct := openAIMetadata{} | ||
metadataInfo := map[string]string{} | ||
metadata.GetMetadataInfoFromStructType(reflect.TypeOf(metadataStruct), &metadataInfo, metadata.BindingType) | ||
return metadataInfo | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.