From 02fcb9f4a6cb474671c54c8cb2f2fe2222d31be6 Mon Sep 17 00:00:00 2001 From: Adam Hendel <15756360+ChuckHend@users.noreply.github.com> Date: Mon, 23 Oct 2023 14:19:00 -0500 Subject: [PATCH] trim inputs that exceeed openai len --- src/api.rs | 18 +++++---- src/init.rs | 1 + src/lib.rs | 2 - src/openai.rs | 107 +++++++++++++++++++++++++++++++++++++++++++++++--- src/util.rs | 2 +- src/worker.rs | 12 ++---- 6 files changed, 118 insertions(+), 24 deletions(-) diff --git a/src/api.rs b/src/api.rs index 26f7520..a3677ac 100644 --- a/src/api.rs +++ b/src/api.rs @@ -1,12 +1,13 @@ use crate::executor::ColumnJobParams; use crate::init; -use crate::openai::get_embeddings; +use crate::openai::openai_embeddings; use crate::search::cosine_similarity_search; use crate::types; use crate::util; use anyhow::Result; use pgrx::prelude::*; +#[allow(clippy::too_many_arguments)] #[pg_extern] fn table( table: &str, @@ -144,13 +145,14 @@ fn search( let schema = project_meta.schema; let table = project_meta.table; - let embeddings = - match runtime.block_on(async { get_embeddings(&vec![query.to_string()], api_key).await }) { - Ok(e) => e, - Err(e) => { - error!("error getting embeddings: {}", e); - } - }; + let embeddings = match runtime + .block_on(async { openai_embeddings(&vec![query.to_string()], api_key).await }) + { + Ok(e) => e, + Err(e) => { + error!("error getting embeddings: {}", e); + } + }; let search_results = cosine_similarity_search( job_name, &schema, diff --git a/src/init.rs b/src/init.rs index bb9a488..c7e9d2f 100644 --- a/src/init.rs +++ b/src/init.rs @@ -6,6 +6,7 @@ use anyhow::Result; pub const PGMQ_QUEUE_NAME: &str = "vectorize_queue"; +#[allow(non_camel_case_types)] #[derive(Clone, Debug, Serialize, Deserialize, PostgresEnum)] pub enum TableMethod { // append a new column to the existing table diff --git a/src/lib.rs b/src/lib.rs index e472bdb..6a93192 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -21,8 +21,6 @@ extension_sql_file!("../sql/example.sql"); #[cfg(any(test, feature = "pg_test"))] #[pg_schema] mod tests { - use pgrx::prelude::*; - // #[pg_test] // fn test_hello_tembo() { // assert_eq!("Hello, tembo", crate::hello_tembo()); diff --git a/src/openai.rs b/src/openai.rs index 84be26f..7956231 100644 --- a/src/openai.rs +++ b/src/openai.rs @@ -3,6 +3,14 @@ use serde_json::json; use anyhow::Result; +use crate::executor::Inputs; + +// max token length is 8192 +// however, depending on content of text, token count can be higher than +// token count returned by split_whitespace() +pub const MAX_TOKEN_LEN: usize = 7500; +pub const OPENAI_EMBEDDING_RL: &str = "https://api.openai.com/v1/embeddings"; + #[derive(serde::Deserialize, Debug)] struct EmbeddingResponse { // object: String, @@ -16,14 +24,31 @@ struct DataObject { embedding: Vec, } -pub async fn get_embeddings(inputs: &Vec, key: &str) -> Result>> { - // let len = inputs.len(); - // vec![vec![0.0; 1536]; len] - let url = "https://api.openai.com/v1/embeddings"; +// OpenAI embedding model has a limit of 8192 tokens per input +// there can be a number of ways condense the inputs +pub fn trim_inputs(inputs: &[Inputs]) -> Vec { + inputs + .iter() + .map(|input| { + if input.token_estimate as usize > MAX_TOKEN_LEN { + let tokens: Vec<&str> = input.inputs.split_whitespace().collect(); + tokens + .into_iter() + .take(MAX_TOKEN_LEN) + .collect::>() + .join(" ") + } else { + input.inputs.clone() + } + }) + .collect() +} + +pub async fn openai_embeddings(inputs: &Vec, key: &str) -> Result>> { log!("pg-vectorize: openai request size: {}", inputs.len()); let client = reqwest::Client::new(); let resp = client - .post(url) + .post(OPENAI_EMBEDDING_RL) .json(&json!({ "input": inputs, "model": "text-embedding-ada-002" @@ -59,3 +84,75 @@ pub async fn handle_response serde::Deserialize<'de>>( let value = resp.json::().await?; Ok(value) } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_trim_inputs_no_trimming_required() { + let data = vec![ + Inputs { + record_id: "1".to_string(), + inputs: "token1 token2".to_string(), + token_estimate: 2, + }, + Inputs { + record_id: "2".to_string(), + inputs: "token3 token4".to_string(), + token_estimate: 2, + }, + ]; + + let trimmed = trim_inputs(&data); + assert_eq!(trimmed, vec!["token1 token2", "token3 token4"]); + } + + #[test] + fn test_trim_inputs_trimming_required() { + let token_len = 1000000; + let long_input = (0..token_len) + .map(|i| format!("token{}", i)) + .collect::>() + .join(" "); + + let num_tokens = long_input.split_whitespace().count(); + assert_eq!(num_tokens, token_len); + + let data = vec![Inputs { + record_id: "1".to_string(), + inputs: long_input.clone(), + token_estimate: token_len as i32, + }]; + + let trimmed = trim_inputs(&data); + let trimmed_input = trimmed[0].clone(); + let trimmed_length = trimmed_input.split_whitespace().count(); + assert_eq!(trimmed_length, MAX_TOKEN_LEN); + } + + #[test] + fn test_trim_inputs_mixed_cases() { + let num_tokens_in = 1000000; + let long_input = (0..num_tokens_in) + .map(|i| format!("token{}", i)) + .collect::>() + .join(" "); + let data = vec![ + Inputs { + record_id: "1".to_string(), + inputs: "token1 token2".to_string(), + token_estimate: 2, + }, + Inputs { + record_id: "2".to_string(), + inputs: long_input.clone(), + token_estimate: num_tokens_in, + }, + ]; + + let trimmed = trim_inputs(&data); + assert_eq!(trimmed[0].split_whitespace().count(), 2); + assert_eq!(trimmed[1].split_whitespace().count(), MAX_TOKEN_LEN); + } +} diff --git a/src/util.rs b/src/util.rs index a7ced54..85b3434 100644 --- a/src/util.rs +++ b/src/util.rs @@ -69,7 +69,7 @@ pub fn get_vectorize_meta_spi(job_name: &str) -> Option { vec![(PgBuiltInOids::TEXTOID.oid(), job_name.into_datum())], ); if let Ok(r) = resultset { - return r; + r } else { error!("failed to query vectorize metadata table") } diff --git a/src/worker.rs b/src/worker.rs index f788e24..df70ddb 100644 --- a/src/worker.rs +++ b/src/worker.rs @@ -200,13 +200,7 @@ async fn execute_job(dbclient: Pool, msg: Message) -> Resu let embeddings: Result> = match job_meta.transformer { types::Transformer::openai => { log!("pg-vectorize: OpenAI transformer"); - let text_inputs: Vec = msg - .message - .inputs - .clone() - .into_iter() - .map(|v| v.inputs) - .collect(); + let apikey = match job_params .api_key .ok_or_else(|| anyhow::anyhow!("missing api key")) @@ -218,7 +212,9 @@ async fn execute_job(dbclient: Pool, msg: Message) -> Resu } }; - let embeddings = match openai::get_embeddings(&text_inputs, &apikey).await { + // trims any inputs that exceed openAIs max token length + let text_inputs = openai::trim_inputs(&msg.message.inputs); + let embeddings = match openai::openai_embeddings(&text_inputs, &apikey).await { Ok(e) => e, Err(e) => { warning!("pg-vectorize: Error getting embeddings: {}", e);