Skip to content

Commit

Permalink
feat(anthropic): add optional system prompt
Browse files Browse the repository at this point in the history
- Add `system` field to `Create` command
- Handle system prompt by prepending to user prompt
- Adjust prompt trimming to account for system prompt
- Cleanup multiple newlines and extra spaces in prompts
  • Loading branch information
cloudbridgeuy committed Jun 14, 2023
1 parent ad134be commit e9c7b4f
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 9 deletions.
37 changes: 28 additions & 9 deletions crates/anthropic/src/complete.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ pub struct Api {
// Complete Properties (https://console.anthropic.com/docs/api/reference)
pub model: Model,
pub prompt: String,
pub system: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub max_tokens_to_sample: Option<u32>,
#[serde(skip_serializing_if = "Option::is_none")]
Expand Down Expand Up @@ -176,10 +177,17 @@ impl Api {
api.max_supported_tokens = None;
api.session = None;

api.prompt = trim_prompt(
api.prompt.to_string(),
max_supported_tokens - max_tokens_to_sample,
)?;
let mut max = max_supported_tokens - max_tokens_to_sample;

if let Some(system) = &api.system {
log::debug!("system: {:?}", system);
api.prompt = prepare_prompt(format!("{}\n{}", system, prompt));
max -= token_length(system.to_string()) as u32
}

api.prompt = trim_prompt(api.prompt.to_string(), max)?;

log::debug!("trimmed prompt: {}", api.prompt);

let request = serde_json::to_string(api)?;

Expand Down Expand Up @@ -270,19 +278,26 @@ pub fn serialize_sessions_file(session_file: &str, complete_api: &Api) -> Result
Ok(())
}

/// Estimates the token length of a prompt.
///
/// Uses a rough words-to-tokens heuristic: word count × 4/3 (i.e. ~0.75
/// words per token), which approximates typical English tokenization.
/// TODO: Make this better! (e.g. use a real tokenizer for the model.)
fn token_length(prompt: String) -> usize {
    // `count()` avoids the previous needless `rev()` + `collect::<Vec<_>>()`
    // allocation — we only ever needed the number of whitespace-separated words.
    prompt.split_whitespace().count() * 4 / 3
}

/// Trims the size of the prompt to match the max value.
fn trim_prompt(prompt: String, max: u32) -> Result<String> {
let prompt = prepare_prompt(prompt);

let mut words = prompt.split_whitespace().rev().collect::<Vec<&str>>();

// Estimate the total tokens by multiplying words by 4/3
let tokens = words.len() * 4 / 3;
let tokens = token_length(prompt.clone());

if tokens as u32 <= max {
return Ok(prompt);
}

let mut words = prompt.split_whitespace().rev().collect::<Vec<&str>>();

// Because we need to add back "\n\nHuman:" back to the prompt.
let diff = words.len() - (max + 3) as usize;

Expand All @@ -298,9 +313,13 @@ fn trim_prompt(prompt: String, max: u32) -> Result<String> {
fn prepare_prompt(prompt: String) -> String {
let mut prompt = "\n\nHuman: ".to_string() + &prompt + "\n\nAssistant:";

prompt = prompt.replace("\n\n\n", "\n\n");
prompt = prompt.replace("Human: Human:", "Human:");
prompt = prompt.replace("\n\nHuman:\n\nHuman: ", "\n\nHuman: ");
prompt = prompt.replace("\n\nHuman: \nHuman: ", "\n\nHuman: ");
prompt = prompt.replace("\n\nHuman: \n\nHuman: ", "\n\nHuman: ");
prompt = prompt.replace("\n\nHuman:\n\nAssistant:", "\n\nAssistant:");
prompt = prompt.replace("\n\nHuman: \n\nAssistant:", "\n\nAssistant:");
prompt = prompt.replace("\n\nAssistant:\n\nAssistant:", "\n\nAssistant:");
prompt = prompt.replace("\n\nAssistant: \n\nAssistant: ", "\n\nAssistant:");

Expand Down
5 changes: 5 additions & 0 deletions crates/b/src/anthropic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ impl CompleteCreateCommand {
match command {
AnthropicCommands::Create {
prompt,
system,
model,
max_tokens_to_sample,
stop_sequences,
Expand Down Expand Up @@ -59,6 +60,10 @@ impl CompleteCreateCommand {

log::debug!("model: {:?}", api.model);

if system.is_some() {
api.system = system.clone();
}

max_tokens_to_sample.map(|s| api.max_tokens_to_sample = Some(s));
max_supported_tokens.map(|s| api.max_supported_tokens = Some(s));
stream.map(|s| api.stream = Some(s));
Expand Down
6 changes: 6 additions & 0 deletions crates/b/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,12 @@ pub enum AnthropicCommands {
Create {
/// The prompt you want Claude to complete.
prompt: String,
/// The system prompt is an optional initial prompt that you can include with every
/// message. This is similar to how `system` prompts work with OpenAI Chat GPT models.
/// It's recommended that you use the `\n\nHuman:` and `\n\nAssistant:` stop tokens to
/// create the system prompt.
#[arg(long)]
system: Option<String>,
/// Chat session name. Will be used to store previous session interactions.
#[arg(long)]
session: Option<String>,
Expand Down

0 comments on commit e9c7b4f

Please sign in to comment.