From a8d68923492609a823255ee46075e08c5de42ebd Mon Sep 17 00:00:00 2001 From: ItalyPaleAle <43508+ItalyPaleAle@users.noreply.github.com> Date: Fri, 21 Jul 2023 07:15:36 -0700 Subject: [PATCH] Added stop property Signed-off-by: ItalyPaleAle <43508+ItalyPaleAle@users.noreply.github.com> --- bindings/azure/openai/openai.go | 26 +++++++++++++++++++------- 1 file changed, 19 insertions(+), 7 deletions(-) diff --git a/bindings/azure/openai/openai.go b/bindings/azure/openai/openai.go index 2dfe039f23..81743b5c81 100644 --- a/bindings/azure/openai/openai.go +++ b/bindings/azure/openai/openai.go @@ -73,6 +73,7 @@ type ChatMessages struct { N int32 `json:"n"` PresencePenalty float32 `json:"presencePenalty"` FrequencyPenalty float32 `json:"frequencyPenalty"` + Stop []string `json:"stop"` } // Message type stores the messages for bot conversation. @@ -83,13 +84,14 @@ type Message struct { // Prompt type for completion API. type Prompt struct { - Prompt string `json:"prompt"` - Temperature float32 `json:"temperature"` - MaxTokens int32 `json:"maxTokens"` - TopP float32 `json:"topP"` - N int32 `json:"n"` - PresencePenalty float32 `json:"presencePenalty"` - FrequencyPenalty float32 `json:"frequencyPenalty"` + Prompt string `json:"prompt"` + Temperature float32 `json:"temperature"` + MaxTokens int32 `json:"maxTokens"` + TopP float32 `json:"topP"` + N int32 `json:"n"` + PresencePenalty float32 `json:"presencePenalty"` + FrequencyPenalty float32 `json:"frequencyPenalty"` + Stop []string `json:"stop"` } // NewOpenAI returns a new OpenAI output binding. @@ -218,6 +220,10 @@ func (p *AzOpenAI) completion(ctx context.Context, message []byte, metadata map[ return nil, fmt.Errorf("prompt is required for completion operation") } + if len(prompt.Stop) == 0 { + prompt.Stop = nil + } + resp, err := p.client.GetCompletions(ctx, azopenai.CompletionsOptions{ DeploymentID: p.deploymentID, Prompt: []string{prompt.Prompt}, @@ -225,6 +231,7 @@ func (p *AzOpenAI) completion(ctx context.Context, message []byte, metadata map[ Temperature: &prompt.Temperature, TopP: &prompt.TopP, N: &prompt.N, + Stop: prompt.Stop, }, nil) if err != nil { return nil, fmt.Errorf("error getting completion api: %w", err) @@ -261,6 +268,10 @@ func (p *AzOpenAI) chatCompletion(ctx context.Context, messageRequest []byte, me return nil, fmt.Errorf("messages are required for chat-completion operation") } + if len(messages.Stop) == 0 { + messages.Stop = nil + } + messageReq := make([]azopenai.ChatMessage, len(messages.Messages)) for i, m := range messages.Messages { messageReq[i] = azopenai.ChatMessage{ @@ -281,6 +292,7 @@ func (p *AzOpenAI) chatCompletion(ctx context.Context, messageRequest []byte, me TopP: &messages.TopP, N: &messages.N, Messages: messageReq, + Stop: messages.Stop, }, nil) if err != nil { return nil, fmt.Errorf("error getting chat completion api: %w", err)