Skip to content

Commit

Permalink
trim inputs that exceed openai len
Browse files Browse the repository at this point in the history
  • Loading branch information
ChuckHend committed Oct 23, 2023
1 parent edd2ed7 commit 02fcb9f
Show file tree
Hide file tree
Showing 6 changed files with 118 additions and 24 deletions.
18 changes: 10 additions & 8 deletions src/api.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
use crate::executor::ColumnJobParams;
use crate::init;
use crate::openai::get_embeddings;
use crate::openai::openai_embeddings;
use crate::search::cosine_similarity_search;
use crate::types;
use crate::util;
use anyhow::Result;
use pgrx::prelude::*;

#[allow(clippy::too_many_arguments)]
#[pg_extern]
fn table(
table: &str,
Expand Down Expand Up @@ -144,13 +145,14 @@ fn search(
let schema = project_meta.schema;
let table = project_meta.table;

let embeddings =
match runtime.block_on(async { get_embeddings(&vec![query.to_string()], api_key).await }) {
Ok(e) => e,
Err(e) => {
error!("error getting embeddings: {}", e);
}
};
let embeddings = match runtime
.block_on(async { openai_embeddings(&vec![query.to_string()], api_key).await })
{
Ok(e) => e,
Err(e) => {
error!("error getting embeddings: {}", e);
}
};
let search_results = cosine_similarity_search(
job_name,
&schema,
Expand Down
1 change: 1 addition & 0 deletions src/init.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ use anyhow::Result;

pub const PGMQ_QUEUE_NAME: &str = "vectorize_queue";

#[allow(non_camel_case_types)]
#[derive(Clone, Debug, Serialize, Deserialize, PostgresEnum)]
pub enum TableMethod {
// append a new column to the existing table
Expand Down
2 changes: 0 additions & 2 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,6 @@ extension_sql_file!("../sql/example.sql");
#[cfg(any(test, feature = "pg_test"))]
#[pg_schema]
mod tests {
use pgrx::prelude::*;

// #[pg_test]
// fn test_hello_tembo() {
// assert_eq!("Hello, tembo", crate::hello_tembo());
Expand Down
107 changes: 102 additions & 5 deletions src/openai.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,14 @@ use serde_json::json;

use anyhow::Result;

use crate::executor::Inputs;

// max token length is 8192
// however, depending on content of text, token count can be higher than
// token count returned by split_whitespace()
// so stay well below the model's hard 8192-token limit to leave headroom
pub const MAX_TOKEN_LEN: usize = 7500;
// NOTE(review): name looks like a typo — presumably OPENAI_EMBEDDING_URL;
// renaming would require updating every call site, so flagging only.
pub const OPENAI_EMBEDDING_RL: &str = "https://api.openai.com/v1/embeddings";

#[derive(serde::Deserialize, Debug)]
struct EmbeddingResponse {
// object: String,
Expand All @@ -16,14 +24,31 @@ struct DataObject {
embedding: Vec<f64>,
}

pub async fn get_embeddings(inputs: &Vec<String>, key: &str) -> Result<Vec<Vec<f64>>> {
// let len = inputs.len();
// vec![vec![0.0; 1536]; len]
let url = "https://api.openai.com/v1/embeddings";
// OpenAI embedding model has a limit of 8192 tokens per input;
// there are a number of ways to condense the inputs.
//
/// Returns the text of each input record, truncating any record whose
/// estimated token count exceeds `MAX_TOKEN_LEN`.
///
/// Truncation is whitespace-based: the text is split on whitespace and only
/// the first `MAX_TOKEN_LEN` tokens are kept, re-joined with single spaces —
/// so original spacing of trimmed inputs is not preserved. Inputs at or
/// under the estimate threshold are returned unchanged.
pub fn trim_inputs(inputs: &[Inputs]) -> Vec<String> {
    inputs
        .iter()
        .map(|input| {
            if input.token_estimate as usize > MAX_TOKEN_LEN {
                // take() before collect(): only the first MAX_TOKEN_LEN
                // tokens are ever materialized, instead of collecting every
                // token of a potentially huge input and then discarding most.
                input
                    .inputs
                    .split_whitespace()
                    .take(MAX_TOKEN_LEN)
                    .collect::<Vec<_>>()
                    .join(" ")
            } else {
                input.inputs.clone()
            }
        })
        .collect()
}

pub async fn openai_embeddings(inputs: &Vec<String>, key: &str) -> Result<Vec<Vec<f64>>> {
log!("pg-vectorize: openai request size: {}", inputs.len());
let client = reqwest::Client::new();
let resp = client
.post(url)
.post(OPENAI_EMBEDDING_RL)
.json(&json!({
"input": inputs,
"model": "text-embedding-ada-002"
Expand Down Expand Up @@ -59,3 +84,75 @@ pub async fn handle_response<T: for<'de> serde::Deserialize<'de>>(
let value = resp.json::<T>().await?;
Ok(value)
}

#[cfg(test)]
mod tests {
    use super::*;

    // Convenience constructor for a test record.
    fn record(id: &str, text: String, estimate: i32) -> Inputs {
        Inputs {
            record_id: id.to_string(),
            inputs: text,
            token_estimate: estimate,
        }
    }

    // "token0 token1 ... token{n-1}" — n whitespace-separated tokens.
    fn token_string(n: usize) -> String {
        let tokens: Vec<String> = (0..n).map(|i| format!("token{}", i)).collect();
        tokens.join(" ")
    }

    #[test]
    fn test_trim_inputs_no_trimming_required() {
        // both inputs are far below MAX_TOKEN_LEN, so they pass through untouched
        let data = vec![
            record("1", "token1 token2".to_string(), 2),
            record("2", "token3 token4".to_string(), 2),
        ];
        assert_eq!(trim_inputs(&data), vec!["token1 token2", "token3 token4"]);
    }

    #[test]
    fn test_trim_inputs_trimming_required() {
        let token_len = 1000000;
        let long_input = token_string(token_len);
        // sanity check: the generated string really contains token_len tokens
        assert_eq!(long_input.split_whitespace().count(), token_len);

        let data = vec![record("1", long_input, token_len as i32)];
        let trimmed = trim_inputs(&data);
        // an oversized input is cut down to exactly MAX_TOKEN_LEN tokens
        assert_eq!(trimmed[0].split_whitespace().count(), MAX_TOKEN_LEN);
    }

    #[test]
    fn test_trim_inputs_mixed_cases() {
        let num_tokens_in = 1000000;
        let data = vec![
            record("1", "token1 token2".to_string(), 2),
            record("2", token_string(num_tokens_in), num_tokens_in as i32),
        ];
        let trimmed = trim_inputs(&data);
        // short input untouched, long input trimmed to the cap
        assert_eq!(trimmed[0].split_whitespace().count(), 2);
        assert_eq!(trimmed[1].split_whitespace().count(), MAX_TOKEN_LEN);
    }
}
2 changes: 1 addition & 1 deletion src/util.rs
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ pub fn get_vectorize_meta_spi(job_name: &str) -> Option<pgrx::JsonB> {
vec![(PgBuiltInOids::TEXTOID.oid(), job_name.into_datum())],
);
if let Ok(r) = resultset {
return r;
r
} else {
error!("failed to query vectorize metadata table")
}
Expand Down
12 changes: 4 additions & 8 deletions src/worker.rs
Original file line number Diff line number Diff line change
Expand Up @@ -200,13 +200,7 @@ async fn execute_job(dbclient: Pool<Postgres>, msg: Message<JobMessage>) -> Resu
let embeddings: Result<Vec<PairedEmbeddings>> = match job_meta.transformer {
types::Transformer::openai => {
log!("pg-vectorize: OpenAI transformer");
let text_inputs: Vec<String> = msg
.message
.inputs
.clone()
.into_iter()
.map(|v| v.inputs)
.collect();

let apikey = match job_params
.api_key
.ok_or_else(|| anyhow::anyhow!("missing api key"))
Expand All @@ -218,7 +212,9 @@ async fn execute_job(dbclient: Pool<Postgres>, msg: Message<JobMessage>) -> Resu
}
};

let embeddings = match openai::get_embeddings(&text_inputs, &apikey).await {
// trims any inputs that exceed openAIs max token length
let text_inputs = openai::trim_inputs(&msg.message.inputs);
let embeddings = match openai::openai_embeddings(&text_inputs, &apikey).await {
Ok(e) => e,
Err(e) => {
warning!("pg-vectorize: Error getting embeddings: {}", e);
Expand Down

0 comments on commit 02fcb9f

Please sign in to comment.