Skip to content

Commit

Permalink
Added stop property
Browse files Browse the repository at this point in the history
Signed-off-by: ItalyPaleAle <[email protected]>
  • Loading branch information
ItalyPaleAle committed Jul 21, 2023
1 parent 6d21f17 commit a8d6892
Showing 1 changed file with 19 additions and 7 deletions.
26 changes: 19 additions & 7 deletions bindings/azure/openai/openai.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ type ChatMessages struct {
N int32 `json:"n"`
PresencePenalty float32 `json:"presencePenalty"`
FrequencyPenalty float32 `json:"frequencyPenalty"`
Stop []string `json:"stop"`
}

// Message type stores the messages for bot conversation.
Expand All @@ -83,13 +84,14 @@ type Message struct {

// Prompt type for completion API.
type Prompt struct {
Prompt string `json:"prompt"`
Temperature float32 `json:"temperature"`
MaxTokens int32 `json:"maxTokens"`
TopP float32 `json:"topP"`
N int32 `json:"n"`
PresencePenalty float32 `json:"presencePenalty"`
FrequencyPenalty float32 `json:"frequencyPenalty"`
Prompt string `json:"prompt"`
Temperature float32 `json:"temperature"`
MaxTokens int32 `json:"maxTokens"`
TopP float32 `json:"topP"`
N int32 `json:"n"`
PresencePenalty float32 `json:"presencePenalty"`
FrequencyPenalty float32 `json:"frequencyPenalty"`
Stop []string `json:"stop"`
}

// NewOpenAI returns a new OpenAI output binding.
Expand Down Expand Up @@ -218,13 +220,18 @@ func (p *AzOpenAI) completion(ctx context.Context, message []byte, metadata map[
return nil, fmt.Errorf("prompt is required for completion operation")
}

if len(prompt.Stop) == 0 {
prompt.Stop = nil
}

resp, err := p.client.GetCompletions(ctx, azopenai.CompletionsOptions{
DeploymentID: p.deploymentID,
Prompt: []string{prompt.Prompt},
MaxTokens: &prompt.MaxTokens,
Temperature: &prompt.Temperature,
TopP: &prompt.TopP,
N: &prompt.N,
Stop: prompt.Stop,
}, nil)
if err != nil {
return nil, fmt.Errorf("error getting completion api: %w", err)
Expand Down Expand Up @@ -261,6 +268,10 @@ func (p *AzOpenAI) chatCompletion(ctx context.Context, messageRequest []byte, me
return nil, fmt.Errorf("messages are required for chat-completion operation")
}

if len(messages.Stop) == 0 {
messages.Stop = nil
}

messageReq := make([]azopenai.ChatMessage, len(messages.Messages))
for i, m := range messages.Messages {
messageReq[i] = azopenai.ChatMessage{
Expand All @@ -281,6 +292,7 @@ func (p *AzOpenAI) chatCompletion(ctx context.Context, messageRequest []byte, me
TopP: &messages.TopP,
N: &messages.N,
Messages: messageReq,
Stop: messages.Stop,
}, nil)
if err != nil {
return nil, fmt.Errorf("error getting chat completion api: %w", err)
Expand Down

0 comments on commit a8d6892

Please sign in to comment.