diff --git a/.github/workflows/extension_ci.yml b/.github/workflows/extension_ci.yml
index 8b847ad..f901132 100644
--- a/.github/workflows/extension_ci.yml
+++ b/.github/workflows/extension_ci.yml
@@ -76,44 +76,44 @@ jobs:
       - name: Clippy
         run: cargo clippy
 
-  # test:
-  #   name: Run tests
-  #   needs: dependencies
-  #   runs-on: ubuntu-22.04
-  #   steps:
-  #     - uses: actions/checkout@v2
-  #     - name: Install Rust stable toolchain
-  #       uses: actions-rs/toolchain@v1
-  #       with:
-  #         toolchain: stable
-  #     - uses: Swatinem/rust-cache@v2
-  #       with:
-  #         prefix-key: "pg-vectorize-extension-test"
-  #         workspaces: pg-vectorize
-  #         # Additional directories to cache
-  #         cache-directories: /home/runner/.pgrx
-  #     - uses: ./.github/actions/pgx-init
-  #       with:
-  #         working-directory: ./
-  #     - name: Restore cached binaries
-  #       uses: actions/cache@v2
-  #       with:
-  #         path: |
-  #           /usr/local/bin/stoml
-  #           ~/.cargo/bin/trunk
-  #         key: ${{ runner.os }}-bins-${{ github.sha }}
-  #         restore-keys: |
-  #           ${{ runner.os }}-bins-
-  #     - name: test
-  #       run: |
-  #         pgrx15_config=$(/usr/local/bin/stoml ~/.pgrx/config.toml configs.pg15)
-  #         ~/.cargo/bin/trunk install pgvector --pg-config ${pgrx15_config}
-  #         ~/.cargo/bin/trunk install pgmq --pg-config ${pgrx15_config}
-  #         ~/.cargo/bin/trunk install pg_cron --pg-config ${pgrx15_config}
-  #         rm -rf ./target/pgrx-test-data-* || true
-  #         pg_version=$(/usr/local/bin/stoml Cargo.toml features.default)
-  #         cargo pgrx run ${pg_version} --pgcli || true
-  #         cargo pgrx test ${pg_version}
+  test:
+    name: Run tests
+    needs: dependencies
+    runs-on: ubuntu-22.04
+    steps:
+      - uses: actions/checkout@v2
+      - name: Install Rust stable toolchain
+        uses: actions-rs/toolchain@v1
+        with:
+          toolchain: stable
+      - uses: Swatinem/rust-cache@v2
+        with:
+          prefix-key: "pg-vectorize-extension-test"
+          workspaces: pg-vectorize
+          # Additional directories to cache
+          cache-directories: /home/runner/.pgrx
+      - uses: ./.github/actions/pgx-init
+        with:
+          working-directory: ./
+      - name: Restore cached binaries
+        uses: actions/cache@v2
+        with:
+          path: |
+            /usr/local/bin/stoml
+            ~/.cargo/bin/trunk
+          key: ${{ runner.os }}-bins-${{ github.sha }}
+          restore-keys: |
+            ${{ runner.os }}-bins-
+      - name: test
+        run: |
+          pgrx15_config=$(/usr/local/bin/stoml ~/.pgrx/config.toml configs.pg15)
+          ~/.cargo/bin/trunk install pgvector --pg-config ${pgrx15_config}
+          ~/.cargo/bin/trunk install pgmq --pg-config ${pgrx15_config}
+          ~/.cargo/bin/trunk install pg_cron --pg-config ${pgrx15_config}
+          rm -rf ./target/pgrx-test-data-* || true
+          pg_version=$(/usr/local/bin/stoml Cargo.toml features.default)
+          cargo pgrx run ${pg_version} --pgcli || true
+          cargo pgrx test ${pg_version}
 
   publish:
     if: github.event_name == 'release'
diff --git a/Cargo.toml b/Cargo.toml
index 24b2973..e6bbe90 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -1,6 +1,6 @@
 [package]
 name = "vectorize"
-version = "0.1.0"
+version = "0.1.1"
 edition = "2021"
 publish = false
diff --git a/Trunk.toml b/Trunk.toml
index 9c489f1..8091fc7 100644
--- a/Trunk.toml
+++ b/Trunk.toml
@@ -6,7 +6,7 @@ description = "The simplest implementation of LLM-backed vector search on Postgres."
 homepage = "https://github.com/tembo-io/pg_vectorize"
 documentation = "https://github.com/tembo-io/pg_vectorize"
 categories = ["orchestration", "machine_learning"]
-version = "0.1.0"
+version = "0.1.1"
 
 [build]
 postgres_version = "15"
diff --git a/src/executor.rs b/src/executor.rs
index e206a5b..d8a37fb 100644
--- a/src/executor.rs
+++ b/src/executor.rs
@@ -67,6 +67,31 @@ pub struct ColumnJobParams {
     pub table_method: TableMethod,
 }
 
+// creates batches based on total token count
+// batch_size is the max token count per batch
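+// note: an input whose token_estimate exceeds batch_size still becomes its own
+// batch, so batch_size acts as a soft cap rather than a hard limit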
+fn create_batches(data: Vec<Inputs>, batch_size: i32) -> Vec<Vec<Inputs>> {
+    let mut groups: Vec<Vec<Inputs>> = Vec::new();
+    let mut current_group: Vec<Inputs> = Vec::new();
+    let mut current_token_count = 0;
+
+    for input in data {
+        if current_token_count + input.token_estimate > batch_size {
+            // Create a new group
+            groups.push(current_group);
+            current_group = Vec::new();
+            current_token_count = 0;
+        }
+        current_token_count += input.token_estimate;
+        current_group.push(input);
+    }
+
+    // Add any remaining inputs to the groups
+    if !current_group.is_empty() {
+        groups.push(current_group);
+    }
+    groups
+}
+
 // schema for all messages that hit pgmq
 #[derive(Clone, Deserialize, Debug, Serialize)]
 pub struct JobMessage {
@@ -87,10 +112,14 @@ fn job_execute(job_name: String) {
         .build()
         .unwrap_or_else(|e| error!("failed to initialize tokio runtime: {}", e));
 
+    // TODO: move into a config
+    // 100k tokens per batch
+    let max_batch_size = 100000;
+
     runtime.block_on(async {
         let conn = get_pg_conn()
             .await
-            .unwrap_or_else(|e| error!("pg-vectorize: failed to establsh db connection: {}", e));
+            .unwrap_or_else(|e| error!("pg-vectorize: failed to establish db connection: {}", e));
         let queue = pgmq::PGMQueueExt::new_with_pool(conn.clone())
             .await
             .unwrap_or_else(|e| error!("failed to init db connection: {}", e));
@@ -106,19 +135,28 @@ fn job_execute(job_name: String) {
         let new_or_updated_rows = get_new_updates_append(&conn, &job_name, job_params)
             .await
             .unwrap_or_else(|e| error!("failed to get new updates: {}", e));
+
         match new_or_updated_rows {
             Some(rows) => {
                 log!("num new records: {}", rows.len());
-                let msg = JobMessage {
-                    job_name: job_name.clone(),
-                    job_meta: meta.clone(),
-                    inputs: rows,
-                };
-                let msg_id = queue
-                    .send(PGMQ_QUEUE_NAME, &msg)
-                    .await
-                    .unwrap_or_else(|e| error!("failed to send message updates: {}", e));
-                log!("message sent: {}", msg_id);
+                let batches = create_batches(rows, max_batch_size);
+                log!(
+                    "total batches: {}, max_batch_size: {}",
+                    batches.len(),
+                    max_batch_size
+                );
+                for b in batches {
+                    let msg = JobMessage {
+                        job_name: job_name.clone(),
+                        job_meta: meta.clone(),
+                        inputs: b,
+                    };
+                    let msg_id = queue
+                        .send(PGMQ_QUEUE_NAME, &msg)
+                        .await
+                        .unwrap_or_else(|e| error!("failed to send message updates: {}", e));
+                    log!("message sent: {}", msg_id);
+                }
             }
             None => {
                 log!("pg-vectorize: job: {}, no new records", job_name);
@@ -149,8 +187,9 @@ pub async fn get_vectorize_meta(
 
 #[derive(Clone, Debug, Deserialize, Serialize)]
 pub struct Inputs {
-    pub record_id: String, // the value to join the record
-    pub inputs: String,    // concatenation of input columns
+    pub record_id: String,   // the value to join the record
+    pub inputs: String,      // concatenation of input columns
+    pub token_estimate: i32, // estimated token count
 }
 
 // queries a table and returns rows that need new embeddings
@@ -188,9 +227,12 @@ pub async fn get_new_updates_append(
     if !rows.is_empty() {
         let mut new_inputs: Vec<Inputs> = Vec::new();
         for r in rows {
+            let ipt: String = r.get("input_text");
+            let token_estimate = ipt.split_whitespace().count() as i32;
             new_inputs.push(Inputs {
                 record_id: r.get("record_id"),
-                inputs: r.get("input_text"),
+                inputs: ipt,
+                token_estimate,
             })
         }
         log!("pg-vectorize: num new inputs: {}", new_inputs.len());
@@ -239,9 +281,12 @@ pub async fn get_new_updates_shared(
     match rows {
         Ok(rows) => {
             for r in rows {
+                let ipt: String = r.get("input_text");
+                let token_estimate = ipt.split_whitespace().count() as i32;
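+                // whitespace-delimited word count is only a rough proxy for the model's token count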
                 new_inputs.push(Inputs {
                     record_id: r.get("record_id"),
-                    inputs: r.get("input_text"),
+                    inputs: ipt,
+                    token_estimate,
                 })
             }
             Ok(Some(new_inputs))
@@ -261,3 +306,66 @@ fn collapse_to_csv(strings: &[String]) -> String {
         .collect::<Vec<String>>()
         .join("|| ', ' ||")
 }
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn test_create_batches_normal() {
+        let data = vec![
+            Inputs {
+                record_id: "1".to_string(),
+                inputs: "Test 1.".to_string(),
+                token_estimate: 2,
+            },
+            Inputs {
+                record_id: "2".to_string(),
+                inputs: "Test 2.".to_string(),
+                token_estimate: 2,
+            },
+            Inputs {
+                record_id: "3".to_string(),
+                inputs: "Test 3.".to_string(),
+                token_estimate: 3,
+            },
+        ];
+
+        let batches = create_batches(data, 4);
+        assert_eq!(batches.len(), 2);
+        assert_eq!(batches[0].len(), 2);
+        assert_eq!(batches[1].len(), 1);
+    }
+
+    #[test]
+    fn test_create_batches_empty() {
+        let data: Vec<Inputs> = Vec::new();
+        let batches = create_batches(data, 4);
+        assert_eq!(batches.len(), 0);
+    }
+
+    #[test]
+    fn test_create_batches_large() {
+        let data = vec![
+            Inputs {
+                record_id: "1".to_string(),
+                inputs: "Test 1.".to_string(),
+                token_estimate: 2,
+            },
+            Inputs {
+                record_id: "2".to_string(),
+                inputs: "Test 2.".to_string(),
+                token_estimate: 2,
+            },
+            Inputs {
+                record_id: "3".to_string(),
+                inputs: "Test 3.".to_string(),
+                token_estimate: 100,
+            },
+        ];
+        let batches = create_batches(data, 5);
+        assert_eq!(batches.len(), 2);
+        assert_eq!(batches[1].len(), 1);
+        assert_eq!(batches[1][0].token_estimate, 100);
+    }
+}
diff --git a/src/openai.rs b/src/openai.rs
index ce20c96..84be26f 100644
--- a/src/openai.rs
+++ b/src/openai.rs
@@ -42,7 +42,6 @@ pub async fn get_embeddings(inputs: &Vec<String>, key: &str) -> Result<Vec<Vec<f64>>> {
 fn handle_response<'de, T: serde::Deserialize<'de>>(
     resp: reqwest::Response,
     method: &'static str,
diff --git a/src/worker.rs b/src/worker.rs
index 4a31ac1..f788e24 100644
--- a/src/worker.rs
+++ b/src/worker.rs
@@ -44,18 +44,17 @@ pub extern "C" fn background_worker_main(_arg: pg_sys::Datum) {
             // on SIGHUP, you might want to reload configurations and env vars
         }
         let _: Result<()> = runtime.block_on(async {
-            let msg: Message<JobMessage> =
-                match queue.read::<JobMessage>(PGMQ_QUEUE_NAME, 300).await {
-                    Ok(Some(msg)) => msg,
-                    Ok(None) => {
-                        log!("pg-vectorize: No messages in queue");
-                        return Ok(());
-                    }
-                    Err(e) => {
-                        warning!("pg-vectorize: Error reading message: {e}");
-                        return Ok(());
-                    }
-                };
+            let msg: Message<JobMessage> = match queue.pop::<JobMessage>(PGMQ_QUEUE_NAME).await {
+                Ok(Some(msg)) => msg,
+                Ok(None) => {
+                    log!("pg-vectorize: No messages in queue");
+                    return Ok(());
+                }
+                Err(e) => {
+                    warning!("pg-vectorize: Error reading message: {e}");
+                    return Ok(());
+                }
+            };
             let msg_id = msg.msg_id;
             let read_ct = msg.read_ct;