Skip to content

Commit

Permalink
Merge branch 'master' into master
Browse files Browse the repository at this point in the history
  • Loading branch information
ItalyPaleAle authored Aug 7, 2023
2 parents d2cc190 + c957420 commit 0cbcd2a
Show file tree
Hide file tree
Showing 83 changed files with 2,199 additions and 381 deletions.
69 changes: 69 additions & 0 deletions .build-tools/builtin-authentication-profiles.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -103,3 +103,72 @@ azuread:
- AzurePublicCloud
- AzureChinaCloud
- AzureUSGovernmentCloud

gcp:
- title: "GCP API Authentication with Service Account Key"
description: |
Authenticate authenticates API calls with the given service account or refresh token JSON credentials.
metadata:
- name: privateKeyID
required: true
sensitive: true
description: |
The GCP private key id. Replace with the value of "private_key_id" field of the Service Account Key file.
example: '"privateKeyID"'
- name: privateKey
required: true
sensitive: true
description: |
The GCP credentials private key. Replace with the value of "private_key" field of the Service Account Key file.
example: '"-----BEGIN PRIVATE KEY-----\nMIIE...\\n-----END PRIVATE KEY-----\n"'
- name: type
type: string
required: false
description: |
The GCP credentials type.
example: '"service_account"'
allowedValues:
- service_account
- name: projectID
type: string
required: true
description: |
GCP project id.
example: '"projectID"'
- name: clientEmail
type: string
required: true
description: |
GCP client email.
example: '"[email protected]"'
- name: clientID
type: string
required: true
description: |
The GCP client ID.
example: '"0123456789-0123456789"'
- name: authURI
type: string
required: false
description: |
The GCP account OAuth2 authorization server endpoint URI.
example: '"https://accounts.google.com/o/oauth2/auth"'
- name: tokenURI
type: string
required: false
description: |
The GCP account token server endpoint URI.
example: '"https://oauth2.googleapis.com/token"'
- name: authProviderX509CertURL
type: string
required: false
description: |
The GCP URL of the public x509 certificate, used to verify the signature
on JWTs, such as ID tokens, signed by the authentication provider.
example: '"https://www.googleapis.com/oauth2/v1/certs"'
- name: clientX509CertURL
type: string
required: false
description: |
The GCP URL of the public x509 certificate, used to verify JWTs signed by the client.
example: '"https://www.googleapis.com/robot/v1/metadata/x509/<PROJECT_NAME>.iam.gserviceaccount.com"'
2 changes: 1 addition & 1 deletion .build-tools/go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ require (
github.com/invopop/jsonschema v0.6.0
github.com/spf13/cobra v1.6.1
github.com/xeipuuv/gojsonschema v1.2.1-0.20201027075954-b076d39a02e5
golang.org/x/exp v0.0.0-20230711153332-06a737ee72cb
golang.org/x/exp v0.0.0-20230801115018-d63ba01acd4b
gopkg.in/yaml.v3 v3.0.1
sigs.k8s.io/yaml v1.3.0
)
Expand Down
4 changes: 2 additions & 2 deletions .build-tools/go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,8 @@ github.com/xeipuuv/gojsonreference v0.0.0-20180127040603-bd5ef7bd5415 h1:EzJWgHo
github.com/xeipuuv/gojsonreference v0.0.0-20180127040603-bd5ef7bd5415/go.mod h1:GwrjFmJcFw6At/Gs6z4yjiIwzuJ1/+UwLxMQDVQXShQ=
github.com/xeipuuv/gojsonschema v1.2.1-0.20201027075954-b076d39a02e5 h1:ImnGIsrcG8vwbovhYvvSY8fagVV6QhCWSWXfzwGDLVs=
github.com/xeipuuv/gojsonschema v1.2.1-0.20201027075954-b076d39a02e5/go.mod h1:anYRn/JVcOK2ZgGU+IjEV4nwlhoK5sQluxsYJ78Id3Y=
golang.org/x/exp v0.0.0-20230711153332-06a737ee72cb h1:xIApU0ow1zwMa2uL1VDNeQlNVFTWMQxZUZCMDy0Q4Us=
golang.org/x/exp v0.0.0-20230711153332-06a737ee72cb/go.mod h1:FXUEEKJgO7OQYeo8N01OfiKP8RXMtf6e8aTskBGqWdc=
golang.org/x/exp v0.0.0-20230801115018-d63ba01acd4b h1:r+vk0EmXNmekl0S0BascoeeoHk/L7wmaW2QF90K+kYI=
golang.org/x/exp v0.0.0-20230801115018-d63ba01acd4b/go.mod h1:FXUEEKJgO7OQYeo8N01OfiKP8RXMtf6e8aTskBGqWdc=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY=
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,10 @@ resource "aws_dynamodb_table" "conformance_test_basic_table" {
billing_mode = "PROVISIONED"
read_capacity = "10"
write_capacity = "10"
ttl {
attribute_name = "expiresAt"
enabled = true
}
attribute {
name = "key"
type = "S"
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/certification.yml
Original file line number Diff line number Diff line change
Expand Up @@ -258,7 +258,7 @@ jobs:
set +e
gotestsum --jsonfile ${{ env.TEST_OUTPUT_FILE_PREFIX }}_certification.json \
--junitfile ${{ env.TEST_OUTPUT_FILE_PREFIX }}_certification.xml --format standard-quiet -- \
-coverprofile=cover.out -covermode=set -tags=certtests -coverpkg=${{ matrix.source-pkg }}
-coverprofile=cover.out -covermode=set -tags=certtests -timeout=30m -coverpkg=${{ matrix.source-pkg }}
status=$?
echo "Completed certification tests for ${{ matrix.component }} ... "
if test $status -ne 0; then
Expand Down
10 changes: 10 additions & 0 deletions .golangci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,8 @@ linters-settings:
- "golang.org/x/net/context": "must use context"
- "github.com/pkg/errors": "must use standard library (errors package and/or fmt.Errorf)"
- "github.com/Sirupsen/logrus": "must use github.com/dapr/kit/logger"
- "github.com/labstack/gommon/log": "must use github.com/dapr/kit/logger"
- "github.com/gobuffalo/logger": "must use github.com/dapr/kit/logger"
- "github.com/agrea/ptr": "must use github.com/dapr/kit/ptr"
- "github.com/cenkalti/backoff$": "must use github.com/cenkalti/backoff/v4"
- "github.com/cenkalti/backoff/v2": "must use github.com/cenkalti/backoff/v4"
Expand All @@ -133,6 +135,14 @@ linters-settings:
- "github.com/golang-jwt/jwt/v4": "must use github.com/lestrrat-go/jwx/v2"
- "github.com/lestrrat-go/jwx/jwa": "must use github.com/lestrrat-go/jwx/v2"
- "github.com/lestrrat-go/jwx/jwt": "must use github.com/lestrrat-go/jwx/v2"
- "github.com/lestrrat-go/jwx/jws": "must use github.com/lestrrat-go/jwx/v2"
- "github.com/gogo/status": "must use google.golang.org/grpc/status"
- "github.com/gogo/protobuf": "must use google.golang.org/protobuf"
- "k8s.io/utils/pointer": "must use github.com/dapr/kit/ptr"
- "k8s.io/utils/ptr": "must use github.com/dapr/kit/ptr"
- "github.com/ghodss/yaml": "must use sigs.k8s.io/yaml"
- "gopkg.in/yaml.v2": "must use gopkg.in/yaml.v3"
- "github.com/go-chi/chi$": "must use github.com/go-chi/chi/v5"
misspell:
# Correct spellings using locale preferences for US or UK.
# Default is to use a neutral variety of English.
Expand Down
6 changes: 1 addition & 5 deletions bindings/azure/openai/metadata.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,4 @@ metadata:
- name: endpoint
required: true
description: "Endpoint of the Azure OpenAI service"
example: '"https://myopenai.openai.azure.com"'
- name: deploymentID
required: true
description: "ID of the model deployment in the Azure OpenAI service"
example: '"my-model"'
example: '"https://myopenai.openai.azure.com"'
70 changes: 58 additions & 12 deletions bindings/azure/openai/openai.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@ import (
"reflect"
"time"

"github.com/Azure/azure-sdk-for-go/sdk/ai/azopenai"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/to"
"github.com/Azure/azure-sdk-for-go/sdk/cognitiveservices/azopenai"

"github.com/dapr/components-contrib/bindings"
azauth "github.com/dapr/components-contrib/internal/authentication/azure"
Expand All @@ -33,6 +33,7 @@ import (
const (
CompletionOperation bindings.OperationKind = "completion"
ChatCompletionOperation bindings.OperationKind = "chat-completion"
GetEmbeddingOperation bindings.OperationKind = "get-embedding"

APIKey = "apiKey"
DeploymentID = "deploymentID"
Expand All @@ -50,22 +51,20 @@ const (

// AzOpenAI represents OpenAI output binding.
type AzOpenAI struct {
logger logger.Logger
client *azopenai.Client
deploymentID string
logger logger.Logger
client *azopenai.Client
}

type openAIMetadata struct {
// APIKey is the API key for the Azure OpenAI API.
APIKey string `mapstructure:"apiKey"`
// DeploymentID is the deployment ID for the Azure OpenAI API.
DeploymentID string `mapstructure:"deploymentID"`
// Endpoint is the endpoint for the Azure OpenAI API.
Endpoint string `mapstructure:"endpoint"`
}

// ChatMessages type for chat completion API.
type ChatMessages struct {
DeploymentID string `json:"deploymentID"`
Messages []Message `json:"messages"`
Temperature float32 `json:"temperature"`
MaxTokens int32 `json:"maxTokens"`
Expand All @@ -84,6 +83,7 @@ type Message struct {

// Prompt type for completion API.
type Prompt struct {
DeploymentID string `json:"deploymentID"`
Prompt string `json:"prompt"`
Temperature float32 `json:"temperature"`
MaxTokens int32 `json:"maxTokens"`
Expand All @@ -94,6 +94,11 @@ type Prompt struct {
Stop []string `json:"stop"`
}

type EmbeddingMessage struct {
DeploymentID string `json:"deploymentID"`
Message string `json:"message"`
}

// NewOpenAI returns a new OpenAI output binding.
func NewOpenAI(logger logger.Logger) bindings.OutputBinding {
return &AzOpenAI{
Expand All @@ -111,9 +116,6 @@ func (p *AzOpenAI) Init(ctx context.Context, meta bindings.Metadata) error {
if m.Endpoint == "" {
return fmt.Errorf("required metadata not set: %s", Endpoint)
}
if m.DeploymentID == "" {
return fmt.Errorf("required metadata not set: %s", DeploymentID)
}

if m.APIKey != "" {
// use API key authentication
Expand Down Expand Up @@ -144,7 +146,6 @@ func (p *AzOpenAI) Init(ctx context.Context, meta bindings.Metadata) error {
return fmt.Errorf("error creating Azure OpenAI client: %w", err)
}
}
p.deploymentID = m.DeploymentID

return nil
}
Expand All @@ -154,6 +155,7 @@ func (p *AzOpenAI) Operations() []bindings.OperationKind {
return []bindings.OperationKind{
ChatCompletionOperation,
CompletionOperation,
GetEmbeddingOperation,
}
}

Expand Down Expand Up @@ -188,6 +190,14 @@ func (p *AzOpenAI) Invoke(ctx context.Context, req *bindings.InvokeRequest) (res
responseAsBytes, _ := json.Marshal(response)
resp.Data = responseAsBytes

case GetEmbeddingOperation:
response, err := p.getEmbedding(ctx, req.Data, req.Metadata)
if err != nil {
return nil, fmt.Errorf("error performing get embedding operation: %w", err)
}
responseAsBytes, _ := json.Marshal(response)
resp.Data = responseAsBytes

default:
return nil, fmt.Errorf(
"invalid operation type: %s. Expected %s, %s",
Expand Down Expand Up @@ -220,12 +230,16 @@ func (p *AzOpenAI) completion(ctx context.Context, message []byte, metadata map[
return nil, fmt.Errorf("prompt is required for completion operation")
}

if prompt.DeploymentID == "" {
return nil, fmt.Errorf("required metadata not set: %s", DeploymentID)
}

if len(prompt.Stop) == 0 {
prompt.Stop = nil
}

resp, err := p.client.GetCompletions(ctx, azopenai.CompletionsOptions{
DeploymentID: p.deploymentID,
DeploymentID: prompt.DeploymentID,
Prompt: []string{prompt.Prompt},
MaxTokens: &prompt.MaxTokens,
Temperature: &prompt.Temperature,
Expand Down Expand Up @@ -268,6 +282,10 @@ func (p *AzOpenAI) chatCompletion(ctx context.Context, messageRequest []byte, me
return nil, fmt.Errorf("messages are required for chat-completion operation")
}

if messages.DeploymentID == "" {
return nil, fmt.Errorf("required metadata not set: %s", DeploymentID)
}

if len(messages.Stop) == 0 {
messages.Stop = nil
}
Expand All @@ -286,7 +304,7 @@ func (p *AzOpenAI) chatCompletion(ctx context.Context, messageRequest []byte, me
}

res, err := p.client.GetChatCompletions(ctx, azopenai.ChatCompletionsOptions{
DeploymentID: p.deploymentID,
DeploymentID: messages.DeploymentID,
MaxTokens: maxTokens,
Temperature: &messages.Temperature,
TopP: &messages.TopP,
Expand All @@ -312,6 +330,34 @@ func (p *AzOpenAI) chatCompletion(ctx context.Context, messageRequest []byte, me
return response, nil
}

func (p *AzOpenAI) getEmbedding(ctx context.Context, messageRequest []byte, metadata map[string]string) (response []float32, err error) {
message := EmbeddingMessage{}
err = json.Unmarshal(messageRequest, &message)
if err != nil {
return nil, fmt.Errorf("error unmarshalling the input object: %w", err)
}

if message.DeploymentID == "" {
return nil, fmt.Errorf("required metadata not set: %s", DeploymentID)
}

res, err := p.client.GetEmbeddings(ctx, azopenai.EmbeddingsOptions{
DeploymentID: message.DeploymentID,
Input: []string{message.Message},
}, nil)
if err != nil {
return nil, fmt.Errorf("error getting embedding api: %w", err)
}

// No embedding returned.
if len(res.Data) == 0 {
return []float32{}, nil
}

response = res.Data[0].Embedding
return response, nil
}

// Close Az OpenAI instance.
func (p *AzOpenAI) Close() error {
p.client = nil
Expand Down
29 changes: 28 additions & 1 deletion bindings/azure/storagequeues/storagequeues.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,12 @@ const (
defaultTTL = 10 * time.Minute
defaultVisibilityTimeout = 30 * time.Second
defaultPollingInterval = 10 * time.Second
dequeueCount = "dequeueCount"
insertionTime = "insertionTime"
expirationTime = "expirationTime"
nextVisibleTime = "nextVisibleTime"
popReceipt = "popReceipt"
messageID = "messageID"
)

type consumer struct {
Expand Down Expand Up @@ -177,9 +183,30 @@ func (d *AzureQueueHelper) Read(ctx context.Context, consumer *consumer) error {
}
}

metadata := make(map[string]string, 6)

if res.Messages[0].MessageID != nil {
metadata[messageID] = *res.Messages[0].MessageID
}
if res.Messages[0].PopReceipt != nil {
metadata[popReceipt] = *res.Messages[0].PopReceipt
}
if res.Messages[0].InsertionTime != nil {
metadata[insertionTime] = res.Messages[0].InsertionTime.Format(time.RFC3339)
}
if res.Messages[0].ExpirationTime != nil {
metadata[expirationTime] = res.Messages[0].ExpirationTime.Format(time.RFC3339)
}
if res.Messages[0].TimeNextVisible != nil {
metadata[nextVisibleTime] = res.Messages[0].TimeNextVisible.Format(time.RFC3339)
}
if res.Messages[0].DequeueCount != nil {
metadata[dequeueCount] = strconv.FormatInt(*res.Messages[0].DequeueCount, 10)
}

_, err = consumer.callback(ctx, &bindings.ReadResponse{
Data: data,
Metadata: map[string]string{},
Metadata: metadata,
})
if err != nil {
return err
Expand Down
Loading

0 comments on commit 0cbcd2a

Please sign in to comment.