From 8b6d39f5c1d1109972e4c2e9a44153d39b4ba1bf Mon Sep 17 00:00:00 2001 From: yaron2 Date: Fri, 11 Oct 2024 14:17:49 -0700 Subject: [PATCH] add anthropic Signed-off-by: yaron2 --- conversation/anthropic/anthropic.go | 118 +++++++++++++++++++++++ conversation/anthropic/metadata.yaml | 29 ++++++ conversation/aws/bedrock/bedrock.go | 19 +--- conversation/aws/bedrock/bedrock_test.go | 25 ----- conversation/langchainroles.go | 34 +++++++ conversation/langchainroles_test.go | 39 ++++++++ 6 files changed, 221 insertions(+), 43 deletions(-) create mode 100644 conversation/anthropic/anthropic.go create mode 100644 conversation/anthropic/metadata.yaml delete mode 100644 conversation/aws/bedrock/bedrock_test.go create mode 100644 conversation/langchainroles.go create mode 100644 conversation/langchainroles_test.go diff --git a/conversation/anthropic/anthropic.go b/conversation/anthropic/anthropic.go new file mode 100644 index 0000000000..344c458d03 --- /dev/null +++ b/conversation/anthropic/anthropic.go @@ -0,0 +1,118 @@ +/* +Copyright 2024 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +package anthropic + +import ( + "context" + "reflect" + + "github.com/dapr/components-contrib/conversation" + "github.com/dapr/components-contrib/metadata" + "github.com/dapr/kit/logger" + kmeta "github.com/dapr/kit/metadata" + + "github.com/tmc/langchaingo/llms" + "github.com/tmc/langchaingo/llms/anthropic" +) + +type Anthropic struct { + llm *anthropic.LLM + + logger logger.Logger +} + +type AnthropicMetadata struct { + Key string `json:"key"` + Model string `json:"model"` +} + +func NewAnthropic(logger logger.Logger) conversation.Conversation { + a := &Anthropic{ + logger: logger, + } + + return a +} + +const defaultModel = "claude-3-5-sonnet-20240620" + +func (a *Anthropic) Init(ctx context.Context, meta conversation.Metadata) error { + m := AnthropicMetadata{} + err := kmeta.DecodeMetadata(meta.Properties, &m) + if err != nil { + return err + } + + model := defaultModel + if m.Model != "" { + model = m.Model + } + + llm, err := anthropic.New( + anthropic.WithModel(model), + anthropic.WithToken(m.Key), + ) + if err != nil { + return err + } + + a.llm = llm + return nil +} + +func (a *Anthropic) GetComponentMetadata() (metadataInfo metadata.MetadataMap) { + metadataStruct := AnthropicMetadata{} + metadata.GetMetadataInfoFromStructType(reflect.TypeOf(metadataStruct), &metadataInfo, metadata.ConversationType) + return +} + +func (a *Anthropic) Converse(ctx context.Context, r *conversation.ConversationRequest) (res *conversation.ConversationResponse, err error) { + messages := make([]llms.MessageContent, 0, len(r.Inputs)) + + for _, input := range r.Inputs { + role := conversation.ConvertLangchainRole(input.Role) + + messages = append(messages, llms.MessageContent{ + Role: role, + Parts: []llms.ContentPart{ + llms.TextPart(input.Message), + }, + }) + } + + resp, err := a.llm.GenerateContent(ctx, messages) + if err != nil { + return nil, err + } + + outputs := make([]conversation.ConversationResult, 0, len(resp.Choices)) + + for i := range resp.Choices { + outputs = append(outputs, conversation.ConversationResult{ + Result: resp.Choices[i].Content, + Parameters: r.Parameters, + }) + } + + res = &conversation.ConversationResponse{ + Outputs: outputs, + } + + return res, nil +} + +func (a *Anthropic) Close() error { + return nil +} diff --git a/conversation/anthropic/metadata.yaml b/conversation/anthropic/metadata.yaml new file mode 100644 index 0000000000..e8e6587a40 --- /dev/null +++ b/conversation/anthropic/metadata.yaml @@ -0,0 +1,29 @@ +# yaml-language-server: $schema=../../../component-metadata-schema.json +schemaVersion: v1 +type: conversation +name: anthropic +version: v1 +status: alpha +title: "Anthropic" +urls: + - title: Reference + url: https://docs.dapr.io/reference/components-reference/supported-conversation/setup-anthropic/ +authenticationProfiles: + - title: "API Key" + description: "Authenticate using an API key" + metadata: + - name: key + type: string + required: true + sensitive: true + description: | + API key for Anthropic. + example: "**********" + default: "" +metadata: + - name: model + required: false + description: | + The Anthropic LLM to use. Defaults to claude-3-5-sonnet-20240620 + type: string + example: 'claude-3-5-sonnet-20240620' diff --git a/conversation/aws/bedrock/bedrock.go b/conversation/aws/bedrock/bedrock.go index 9fbf805d5f..709e74da00 100644 --- a/conversation/aws/bedrock/bedrock.go +++ b/conversation/aws/bedrock/bedrock.go @@ -53,23 +53,6 @@ func NewAWSBedrock(logger logger.Logger) conversation.Conversation { return b } -func convertRole(role conversation.Role) llms.ChatMessageType { - switch role { - case conversation.RoleSystem: - return llms.ChatMessageTypeSystem - case conversation.RoleUser: - return llms.ChatMessageTypeHuman - case conversation.RoleAssistant: - return llms.ChatMessageTypeAI - case conversation.RoleTool: - return llms.ChatMessageTypeTool - case conversation.RoleFunction: - return llms.ChatMessageTypeFunction - default: - return llms.ChatMessageTypeHuman - } -} - func (b *AWSBedrock) Init(ctx context.Context, meta conversation.Metadata) error { m := AWSBedrockMetadata{} err := kmeta.DecodeMetadata(meta.Properties, &m) @@ -111,7 +94,7 @@ func (b *AWSBedrock) Converse(ctx context.Context, r *conversation.ConversationR messages := make([]llms.MessageContent, 0, len(r.Inputs)) for _, input := range r.Inputs { - role := convertRole(input.Role) + role := conversation.ConvertLangchainRole(input.Role) messages = append(messages, llms.MessageContent{ Role: role, diff --git a/conversation/aws/bedrock/bedrock_test.go b/conversation/aws/bedrock/bedrock_test.go deleted file mode 100644 index 63bda0187f..0000000000 --- a/conversation/aws/bedrock/bedrock_test.go +++ /dev/null @@ -1,25 +0,0 @@ -package bedrock - -import ( - "testing" - - "github.com/dapr/components-contrib/conversation" - - "github.com/stretchr/testify/assert" - "github.com/tmc/langchaingo/llms" -) - -func TestConvertRole(t *testing.T) { - roles := map[string]string{ - conversation.RoleSystem: string(llms.ChatMessageTypeSystem), - conversation.RoleAssistant: string(llms.ChatMessageTypeAI), - conversation.RoleFunction: string(llms.ChatMessageTypeFunction), - conversation.RoleUser: string(llms.ChatMessageTypeHuman), - conversation.RoleTool: string(llms.ChatMessageTypeTool), - } - - for k, v := range roles { - r := convertRole(conversation.Role(k)) - assert.Equal(t, v, string(r)) - } -} diff --git a/conversation/langchainroles.go b/conversation/langchainroles.go new file mode 100644 index 0000000000..a3d4a60fd3 --- /dev/null +++ b/conversation/langchainroles.go @@ -0,0 +1,34 @@ +/* +Copyright 2024 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +package conversation + +import "github.com/tmc/langchaingo/llms" + +func ConvertLangchainRole(role Role) llms.ChatMessageType { + switch role { + case RoleSystem: + return llms.ChatMessageTypeSystem + case RoleUser: + return llms.ChatMessageTypeHuman + case RoleAssistant: + return llms.ChatMessageTypeAI + case RoleTool: + return llms.ChatMessageTypeTool + case RoleFunction: + return llms.ChatMessageTypeFunction + default: + return llms.ChatMessageTypeHuman + } +} diff --git a/conversation/langchainroles_test.go b/conversation/langchainroles_test.go new file mode 100644 index 0000000000..5f83de34ba --- /dev/null +++ b/conversation/langchainroles_test.go @@ -0,0 +1,39 @@ +/* +Copyright 2024 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package conversation + +import ( + "testing" + + "github.com/tmc/langchaingo/llms" + + "github.com/stretchr/testify/assert" +) + +func TestConvertLangchainRole(t *testing.T) { + roles := map[string]string{ + RoleSystem: string(llms.ChatMessageTypeSystem), + RoleAssistant: string(llms.ChatMessageTypeAI), + RoleFunction: string(llms.ChatMessageTypeFunction), + RoleUser: string(llms.ChatMessageTypeHuman), + RoleTool: string(llms.ChatMessageTypeTool), + } + + for k, v := range roles { + r := ConvertLangchainRole(Role(k)) + assert.Equal(t, v, string(r)) + } +}