diff --git a/conversation/aws/bedrock/bedrock.go b/conversation/aws/bedrock/bedrock.go index 1bc7b6d2d4..9fbf805d5f 100644 --- a/conversation/aws/bedrock/bedrock.go +++ b/conversation/aws/bedrock/bedrock.go @@ -88,6 +88,7 @@ func (b *AWSBedrock) Init(ctx context.Context, meta conversation.Metadata) error if m.Model != "" { opts = append(opts, bedrock.WithModel(m.Model)) } + b.model = m.Model llm, err := bedrock.New( opts..., diff --git a/conversation/aws/bedrock/bedrock_test.go b/conversation/aws/bedrock/bedrock_test.go index 91e8ec9fb2..63bda0187f 100644 --- a/conversation/aws/bedrock/bedrock_test.go +++ b/conversation/aws/bedrock/bedrock_test.go @@ -4,6 +4,7 @@ import ( "testing" "github.com/dapr/components-contrib/conversation" + "github.com/stretchr/testify/assert" "github.com/tmc/langchaingo/llms" ) diff --git a/conversation/echo/echo_test.go b/conversation/echo/echo_test.go index c7a558ffca..77d8cc57c4 100644 --- a/conversation/echo/echo_test.go +++ b/conversation/echo/echo_test.go @@ -6,7 +6,9 @@ import ( "github.com/dapr/components-contrib/conversation" "github.com/dapr/kit/logger" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestConverse(t *testing.T) { @@ -20,7 +22,7 @@ func TestConverse(t *testing.T) { }, }, }) - assert.NoError(t, err) + require.NoError(t, err) assert.Len(t, r.Outputs, 1) assert.Equal(t, "hello", r.Outputs[0].Result) } diff --git a/conversation/openai/openai_test.go b/conversation/openai/openai_test.go index 8262f32522..19b0ab085b 100644 --- a/conversation/openai/openai_test.go +++ b/conversation/openai/openai_test.go @@ -4,6 +4,7 @@ import ( "testing" "github.com/dapr/components-contrib/conversation" + "github.com/stretchr/testify/assert" openai "github.com/sashabaranov/go-openai" @@ -20,6 +21,6 @@ func TestConvertRole(t *testing.T) { for k, v := range roles { r := convertRole(conversation.Role(k)) - assert.Equal(t, v, string(r)) + assert.Equal(t, v, r) } } diff --git a/tests/conformance/conversation/conversation.go b/tests/conformance/conversation/conversation.go index 287cd14cbe..e909c36061 100644 --- a/tests/conformance/conversation/conversation.go +++ b/tests/conformance/conversation/conversation.go @@ -17,16 +17,12 @@ import ( "context" "testing" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - "github.com/dapr/components-contrib/conversation" "github.com/dapr/components-contrib/metadata" "github.com/dapr/components-contrib/tests/conformance/utils" -) -const ( - conversationComponent = "echo" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) type TestConfig struct {