
Merge pull request #15 from tembo-io/handle_ratelimits
Batch requests
ChuckHend authored Oct 23, 2023
2 parents b7e893e + 4cbcf3b commit edd2ed7
Showing 6 changed files with 174 additions and 68 deletions.
76 changes: 38 additions & 38 deletions .github/workflows/extension_ci.yml
@@ -76,44 +76,44 @@ jobs:
- name: Clippy
run: cargo clippy

# test:
# name: Run tests
# needs: dependencies
# runs-on: ubuntu-22.04
# steps:
# - uses: actions/checkout@v2
# - name: Install Rust stable toolchain
# uses: actions-rs/toolchain@v1
# with:
# toolchain: stable
# - uses: Swatinem/rust-cache@v2
# with:
# prefix-key: "pg-vectorize-extension-test"
# workspaces: pg-vectorize
# # Additional directories to cache
# cache-directories: /home/runner/.pgrx
# - uses: ./.github/actions/pgx-init
# with:
# working-directory: ./
# - name: Restore cached binaries
# uses: actions/cache@v2
# with:
# path: |
# /usr/local/bin/stoml
# ~/.cargo/bin/trunk
# key: ${{ runner.os }}-bins-${{ github.sha }}
# restore-keys: |
# ${{ runner.os }}-bins-
# - name: test
# run: |
# pgrx15_config=$(/usr/local/bin/stoml ~/.pgrx/config.toml configs.pg15)
# ~/.cargo/bin/trunk install pgvector --pg-config ${pgrx15_config}
# ~/.cargo/bin/trunk install pgmq --pg-config ${pgrx15_config}
# ~/.cargo/bin/trunk install pg_cron --pg-config ${pgrx15_config}
# rm -rf ./target/pgrx-test-data-* || true
# pg_version=$(/usr/local/bin/stoml Cargo.toml features.default)
# cargo pgrx run ${pg_version} --pgcli || true
# cargo pgrx test ${pg_version}
test:
name: Run tests
needs: dependencies
runs-on: ubuntu-22.04
steps:
- uses: actions/checkout@v2
- name: Install Rust stable toolchain
uses: actions-rs/toolchain@v1
with:
toolchain: stable
- uses: Swatinem/rust-cache@v2
with:
prefix-key: "pg-vectorize-extension-test"
workspaces: pg-vectorize
# Additional directories to cache
cache-directories: /home/runner/.pgrx
- uses: ./.github/actions/pgx-init
with:
working-directory: ./
- name: Restore cached binaries
uses: actions/cache@v2
with:
path: |
/usr/local/bin/stoml
~/.cargo/bin/trunk
key: ${{ runner.os }}-bins-${{ github.sha }}
restore-keys: |
${{ runner.os }}-bins-
- name: test
run: |
pgrx15_config=$(/usr/local/bin/stoml ~/.pgrx/config.toml configs.pg15)
~/.cargo/bin/trunk install pgvector --pg-config ${pgrx15_config}
~/.cargo/bin/trunk install pgmq --pg-config ${pgrx15_config}
~/.cargo/bin/trunk install pg_cron --pg-config ${pgrx15_config}
rm -rf ./target/pgrx-test-data-* || true
pg_version=$(/usr/local/bin/stoml Cargo.toml features.default)
cargo pgrx run ${pg_version} --pgcli || true
cargo pgrx test ${pg_version}
publish:
if: github.event_name == 'release'
2 changes: 1 addition & 1 deletion Cargo.toml
@@ -1,6 +1,6 @@
[package]
name = "vectorize"
version = "0.1.0"
version = "0.1.1"
edition = "2021"
publish = false

2 changes: 1 addition & 1 deletion Trunk.toml
@@ -6,7 +6,7 @@ description = "The simplest implementation of LLM-backed vector search on Postgr
homepage = "https://github.com/tembo-io/pg_vectorize"
documentation = "https://github.com/tembo-io/pg_vectorize"
categories = ["orchestration", "machine_learning"]
version = "0.1.0"
version = "0.1.1"

[build]
postgres_version = "15"
138 changes: 123 additions & 15 deletions src/executor.rs
@@ -67,6 +67,31 @@ pub struct ColumnJobParams {
pub table_method: TableMethod,
}

// creates batches based on total token count
// batch_size is the max token count per batch
fn create_batches(data: Vec<Inputs>, batch_size: i32) -> Vec<Vec<Inputs>> {
let mut groups: Vec<Vec<Inputs>> = Vec::new();
let mut current_group: Vec<Inputs> = Vec::new();
let mut current_token_count = 0;

for input in data {
if current_token_count + input.token_estimate > batch_size {
// Create a new group
groups.push(current_group);
current_group = Vec::new();
current_token_count = 0;
}
current_token_count += input.token_estimate;
current_group.push(input);
}

// Add any remaining inputs to the groups
if !current_group.is_empty() {
groups.push(current_group);
}
groups
}

// schema for all messages that hit pgmq
#[derive(Clone, Deserialize, Debug, Serialize)]
pub struct JobMessage {
@@ -87,10 +112,14 @@ fn job_execute(job_name: String) {
.build()
.unwrap_or_else(|e| error!("failed to initialize tokio runtime: {}", e));

// TODO: move into a config
// 100k tokens per batch
let max_batch_size = 100000;

runtime.block_on(async {
let conn = get_pg_conn()
.await
.unwrap_or_else(|e| error!("pg-vectorize: failed to establsh db connection: {}", e));
.unwrap_or_else(|e| error!("pg-vectorize: failed to establish db connection: {}", e));
let queue = pgmq::PGMQueueExt::new_with_pool(conn.clone())
.await
.unwrap_or_else(|e| error!("failed to init db connection: {}", e));
Expand All @@ -106,19 +135,28 @@ fn job_execute(job_name: String) {
let new_or_updated_rows = get_new_updates_append(&conn, &job_name, job_params)
.await
.unwrap_or_else(|e| error!("failed to get new updates: {}", e));

match new_or_updated_rows {
Some(rows) => {
log!("num new records: {}", rows.len());
let msg = JobMessage {
job_name: job_name.clone(),
job_meta: meta.clone(),
inputs: rows,
};
let msg_id = queue
.send(PGMQ_QUEUE_NAME, &msg)
.await
.unwrap_or_else(|e| error!("failed to send message updates: {}", e));
log!("message sent: {}", msg_id);
let batches = create_batches(rows, max_batch_size);
log!(
"total batches: {}, max_batch_size: {}",
batches.len(),
max_batch_size
);
for b in batches {
let msg = JobMessage {
job_name: job_name.clone(),
job_meta: meta.clone(),
inputs: b,
};
let msg_id = queue
.send(PGMQ_QUEUE_NAME, &msg)
.await
.unwrap_or_else(|e| error!("failed to send message updates: {}", e));
log!("message sent: {}", msg_id);
}
}
None => {
log!("pg-vectorize: job: {}, no new records", job_name);
@@ -149,8 +187,9 @@ pub async fn get_vectorize_meta(

#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct Inputs {
pub record_id: String, // the value to join the record
pub inputs: String, // concatenation of input columns
pub record_id: String, // the value to join the record
pub inputs: String, // concatenation of input columns
pub token_estimate: i32, // estimated token count
}

// queries a table and returns rows that need new embeddings
@@ -188,9 +227,12 @@ pub async fn get_new_updates_append(
if !rows.is_empty() {
let mut new_inputs: Vec<Inputs> = Vec::new();
for r in rows {
let ipt: String = r.get("input_text");
let token_estimate = ipt.split_whitespace().count() as i32;
new_inputs.push(Inputs {
record_id: r.get("record_id"),
inputs: r.get("input_text"),
inputs: ipt,
token_estimate,
})
}
log!("pg-vectorize: num new inputs: {}", new_inputs.len());
@@ -239,9 +281,12 @@ pub async fn get_new_updates_shared(
match rows {
Ok(rows) => {
for r in rows {
let ipt: String = r.get("input_text");
let token_estimate = ipt.split_whitespace().count() as i32;
new_inputs.push(Inputs {
record_id: r.get("record_id"),
inputs: r.get("input_text"),
inputs: ipt,
token_estimate,
})
}
Ok(Some(new_inputs))
@@ -261,3 +306,66 @@ fn collapse_to_csv(strings: &[String]) -> String {
.collect::<Vec<_>>()
.join("|| ', ' ||")
}

#[cfg(test)]
mod tests {
use super::*;

#[test]
fn test_create_batches_normal() {
let data = vec![
Inputs {
record_id: "1".to_string(),
inputs: "Test 1.".to_string(),
token_estimate: 2,
},
Inputs {
record_id: "2".to_string(),
inputs: "Test 2.".to_string(),
token_estimate: 2,
},
Inputs {
record_id: "3".to_string(),
inputs: "Test 3.".to_string(),
token_estimate: 3,
},
];

let batches = create_batches(data, 4);
assert_eq!(batches.len(), 2);
assert_eq!(batches[0].len(), 2);
assert_eq!(batches[1].len(), 1);
}

#[test]
fn test_create_batches_empty() {
let data: Vec<Inputs> = Vec::new();
let batches = create_batches(data, 4);
assert_eq!(batches.len(), 0);
}

#[test]
fn test_create_batches_large() {
let data = vec![
Inputs {
record_id: "1".to_string(),
inputs: "Test 1.".to_string(),
token_estimate: 2,
},
Inputs {
record_id: "2".to_string(),
inputs: "Test 2.".to_string(),
token_estimate: 2,
},
Inputs {
record_id: "3".to_string(),
inputs: "Test 3.".to_string(),
token_estimate: 100,
},
];
let batches = create_batches(data, 5);
assert_eq!(batches.len(), 2);
assert_eq!(batches[1].len(), 1);
assert_eq!(batches[1][0].token_estimate, 100);
}
}
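
Taken together, the new path through `job_execute` can be read as the following minimal sketch. It is illustrative only: the sample rows and the 8-token budget are made up, the grouping mirrors the `create_batches` function added above, and the real code sends each batch to pgmq as a `JobMessage` rather than printing it.

// Sketch of the batching flow added in src/executor.rs: estimate tokens per
// row by whitespace word count, greedily group rows under a token budget,
// and emit one queue message's worth of inputs per group.
#[derive(Debug)]
struct Inputs {
    record_id: String,
    inputs: String,
    token_estimate: i32,
}

fn main() {
    let rows = [
        "red running shoes, size 10",
        "blue jacket",
        "green waterproof hiking boots for winter",
    ];
    let inputs: Vec<Inputs> = rows
        .iter()
        .enumerate()
        .map(|(i, text)| Inputs {
            record_id: (i + 1).to_string(),
            inputs: text.to_string(),
            // same heuristic as the diff: one "token" per whitespace-separated word
            token_estimate: text.split_whitespace().count() as i32,
        })
        .collect();

    // greedily group rows so each batch stays within the (illustrative) 8-token budget
    let max_batch_size = 8;
    let mut batches: Vec<Vec<Inputs>> = Vec::new();
    let mut current: Vec<Inputs> = Vec::new();
    let mut tokens = 0;
    for input in inputs {
        if tokens + input.token_estimate > max_batch_size {
            batches.push(std::mem::take(&mut current));
            tokens = 0;
        }
        tokens += input.token_estimate;
        current.push(input);
    }
    if !current.is_empty() {
        batches.push(current);
    }

    // the worker builds one JobMessage per batch and sends it to the queue
    for (i, batch) in batches.iter().enumerate() {
        println!("batch {i} -> one queue message with {} rows: {:?}", batch.len(), batch);
    }
}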
1 change: 0 additions & 1 deletion src/openai.rs
@@ -42,7 +42,6 @@ pub async fn get_embeddings(inputs: &Vec<String>, key: &str) -> Result<Vec<Vec<f
Ok(embeddings)
}

// thanks Evan :D
pub async fn handle_response<T: for<'de> serde::Deserialize<'de>>(
resp: reqwest::Response,
method: &'static str,
23 changes: 11 additions & 12 deletions src/worker.rs
@@ -44,18 +44,17 @@ pub extern "C" fn background_worker_main(_arg: pg_sys::Datum) {
// on SIGHUP, you might want to reload configurations and env vars
}
let _: Result<()> = runtime.block_on(async {
let msg: Message<JobMessage> =
match queue.read::<JobMessage>(PGMQ_QUEUE_NAME, 300).await {
Ok(Some(msg)) => msg,
Ok(None) => {
log!("pg-vectorize: No messages in queue");
return Ok(());
}
Err(e) => {
warning!("pg-vectorize: Error reading message: {e}");
return Ok(());
}
};
let msg: Message<JobMessage> = match queue.pop::<JobMessage>(PGMQ_QUEUE_NAME).await {
Ok(Some(msg)) => msg,
Ok(None) => {
log!("pg-vectorize: No messages in queue");
return Ok(());
}
Err(e) => {
warning!("pg-vectorize: Error reading message: {e}");
return Ok(());
}
};

let msg_id = msg.msg_id;
let read_ct = msg.read_ct;
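
One behavioral note on the worker change above, given here as a hedged sketch rather than as part of the commit: with pgmq, `read` only hides a message for the visibility timeout (300 seconds in the old code), so an unfinished message becomes visible again and is retried, whereas `pop` reads and deletes the message in a single call, so there is no automatic redelivery if processing fails afterwards. Both call shapes below are the ones that appear in this diff; `queue`, `JobMessage`, and `PGMQ_QUEUE_NAME` are the crate's existing definitions, and error handling is elided.

// Sketch contrasting the old and new dequeue calls from src/worker.rs.
async fn read_vs_pop(queue: &pgmq::PGMQueueExt) {
    // Old: the message stays in the queue, hidden for 300s; it is re-delivered
    // if the worker never archives or deletes it.
    if let Ok(Some(msg)) = queue.read::<JobMessage>(PGMQ_QUEUE_NAME, 300).await {
        log!("read msg {}, still in queue", msg.msg_id);
    }

    // New: the message is read and removed in one step; no retry on later failure.
    if let Ok(Some(msg)) = queue.pop::<JobMessage>(PGMQ_QUEUE_NAME).await {
        log!("popped msg {}, removed from queue", msg.msg_id);
    }
}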
