From e9c7b4f2bc9f253a4ce7de7f46ff7e02b5c251f1 Mon Sep 17 00:00:00 2001
From: CloudBridge UY
Date: Wed, 14 Jun 2023 10:05:42 -0300
Subject: [PATCH] feat(anthropic): add optional system prompt

- Add `system` field to `Create` command
- Handle system prompt by prepending to user prompt
- Adjust prompt trimming to account for system prompt
- Cleanup multiple newlines and extra spaces in prompts
---
 crates/anthropic/src/complete.rs | 37 ++++++++++++++++++++++++--------
 crates/b/src/anthropic.rs        |  5 +++++
 crates/b/src/lib.rs              |  6 ++++++
 3 files changed, 39 insertions(+), 9 deletions(-)

diff --git a/crates/anthropic/src/complete.rs b/crates/anthropic/src/complete.rs
index 6e2d4e0..ffd87fb 100644
--- a/crates/anthropic/src/complete.rs
+++ b/crates/anthropic/src/complete.rs
@@ -36,6 +36,7 @@ pub struct Api {
     // Complete Properties (https://console.anthropic.com/docs/api/reference)
     pub model: Model,
     pub prompt: String,
+    pub system: Option<String>,
     #[serde(skip_serializing_if = "Option::is_none")]
     pub max_tokens_to_sample: Option<u32>,
     #[serde(skip_serializing_if = "Option::is_none")]
@@ -176,10 +177,17 @@
         api.max_supported_tokens = None;
         api.session = None;
 
-        api.prompt = trim_prompt(
-            api.prompt.to_string(),
-            max_supported_tokens - max_tokens_to_sample,
-        )?;
+        let mut max = max_supported_tokens - max_tokens_to_sample;
+
+        if let Some(system) = &api.system {
+            log::debug!("system: {:?}", system);
+            api.prompt = prepare_prompt(format!("{}\n{}", system, prompt));
+            max -= token_length(system.to_string()) as u32
+        }
+
+        api.prompt = trim_prompt(api.prompt.to_string(), max)?;
+
+        log::debug!("trimmed prompt: {}", api.prompt);
 
         let request = serde_json::to_string(api)?;
 
@@ -270,19 +278,26 @@ pub fn serialize_sessions_file(session_file: &str, complete_api: &Api) -> Result<()> {
     Ok(())
 }
 
+/// Token length of a prompt.
+/// TODO: Make this better!
+fn token_length(prompt: String) -> usize {
+    let words = prompt.split_whitespace().rev().collect::<Vec<&str>>();
+
+    // Estimate the total tokens by multiplying words by 4/3
+    words.len() * 4 / 3
+}
+
 /// Trims the size of the prompt to match the max value.
 fn trim_prompt(prompt: String, max: u32) -> Result<String> {
     let prompt = prepare_prompt(prompt);
-
-    let mut words = prompt.split_whitespace().rev().collect::<Vec<&str>>();
-
-    // Estimate the total tokens by multiplying words by 4/3
-    let tokens = words.len() * 4 / 3;
+    let tokens = token_length(prompt.clone());
 
     if tokens as u32 <= max {
         return Ok(prompt);
     }
 
+    let mut words = prompt.split_whitespace().rev().collect::<Vec<&str>>();
+
     // Because we need to add back "\n\nHuman:" back to the prompt.
     let diff = words.len() - (max + 3) as usize;
 
@@ -298,9 +313,13 @@ fn trim_prompt(prompt: String, max: u32) -> Result<String> {
 fn prepare_prompt(prompt: String) -> String {
     let mut prompt = "\n\nHuman: ".to_string() + &prompt + "\n\nAssistant:";
 
+    prompt = prompt.replace("\n\n\n", "\n\n");
+    prompt = prompt.replace("Human: Human:", "Human:");
     prompt = prompt.replace("\n\nHuman:\n\nHuman: ", "\n\nHuman: ");
     prompt = prompt.replace("\n\nHuman: \nHuman: ", "\n\nHuman: ");
     prompt = prompt.replace("\n\nHuman: \n\nHuman: ", "\n\nHuman: ");
+    prompt = prompt.replace("\n\nHuman:\n\nAssistant:", "\n\nAssistant:");
+    prompt = prompt.replace("\n\nHuman: \n\nAssistant:", "\n\nAssistant:");
     prompt = prompt.replace("\n\nAssistant:\n\nAssistant:", "\n\nAssistant:");
     prompt = prompt.replace("\n\nAssistant: \n\nAssistant: ", "\n\nAssistant:");
 
diff --git a/crates/b/src/anthropic.rs b/crates/b/src/anthropic.rs
index 04b89fc..e5514c0 100644
--- a/crates/b/src/anthropic.rs
+++ b/crates/b/src/anthropic.rs
@@ -16,6 +16,7 @@ impl CompleteCreateCommand {
         match command {
             AnthropicCommands::Create {
                 prompt,
+                system,
                 model,
                 max_tokens_to_sample,
                 stop_sequences,
@@ -59,6 +60,10 @@
 
         log::debug!("model: {:?}", api.model);
 
+        if system.is_some() {
+            api.system = system.clone();
+        }
+
         max_tokens_to_sample.map(|s| api.max_tokens_to_sample = Some(s));
         max_supported_tokens.map(|s| api.max_supported_tokens = Some(s));
         stream.map(|s| api.stream = Some(s));
diff --git a/crates/b/src/lib.rs b/crates/b/src/lib.rs
index dbdf808..b8f5bac 100644
--- a/crates/b/src/lib.rs
+++ b/crates/b/src/lib.rs
@@ -209,6 +209,12 @@ pub enum AnthropicCommands {
     Create {
         /// The prompt you want Claude to complete.
         prompt: String,
+        /// The system prompt is an optional initial prompt that you could include with every
+        /// message. This is similar to how `system` prompts work with OpenAI Chat GPT models.
+        /// It's recommended that you use the `\n\nHuman:` and `\n\nAssistant:` stop tokens to
+        /// create the system prompt.
+        #[arg(long)]
+        system: Option<String>,
         /// Chat session name. Will be used to store previous session interactions.
         #[arg(long)]
         session: Option<String>,