improve error handling
ChuckHend committed Oct 22, 2023
1 parent 6ab45af commit d77b0af
Showing 10 changed files with 208 additions and 194 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/extension_ci.yml
@@ -47,7 +47,7 @@ jobs:
key: ${{ runner.os }}-bins-${{ github.sha }}
restore-keys: |
${{ runner.os }}-bins-
lint:
name: Run linters
runs-on: ubuntu-22.04
12 changes: 8 additions & 4 deletions Cargo.toml
@@ -1,6 +1,6 @@
[package]
name = "vectorize"
version = "0.0.2"
version = "0.1.0"
edition = "2021"
publish = false

@@ -16,16 +16,20 @@ pg_test = []
anyhow = "1.0.72"
chrono = {version = "0.4.26", features = ["serde"] }
log = "0.4.19"
pgmq = "0.14.0"
pgmq = "0.24.0"
pgrx = ">=0.8.0"
postgres-types = "0.2.5"
regex = "1.9.2"
reqwest = {version = "0.11.18", features = ["json"] }
serde = "1.0.173"
serde_json = "1.0.103"
sqlx = {version = "0.6.3", features = [ "offline", "runtime-tokio-native-tls" , "postgres", "chrono", "json" ] }
sqlx = { version = "0.7.2", features = [
"runtime-tokio-native-tls",
"postgres",
"chrono",
] }
thiserror = "1.0.44"
tokio = "1.29.1"
tokio = {version = "1.29.1", features = ["rt-multi-thread"] }
url = "2.4.0"

[dev-dependencies]
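
The bumped tokio dependency now enables the `rt-multi-thread` feature, which is what makes Tokio's multi-threaded runtime builder available (only the current-thread builder compiles without it); the dropped `offline` feature on sqlx matches the 0.7 release, where offline query data is built in rather than feature-gated. A minimal standalone sketch of what the new tokio flag unlocks (illustrative only, not code from this commit):

    // Requires: tokio = { version = "1", features = ["rt-multi-thread"] }
    use tokio::runtime::Builder;

    fn main() {
        let runtime = Builder::new_multi_thread()
            .worker_threads(2)
            .enable_io()
            .enable_time()
            .build()
            .expect("failed to build tokio runtime");

        runtime.block_on(async {
            println!("running inside the multi-thread runtime");
        });
    }
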
3 changes: 2 additions & 1 deletion Makefile
@@ -1,4 +1,5 @@
SQLX_OFFLINE:=true
DATABASE_URL:=postgres://${USER}:${USER}@localhost:28815/postgres

sqlx-cache:
cargo sqlx prepare
@@ -9,4 +10,4 @@ format:

# ensure the DATABASE_URL is not used, since pgrx will stop postgres during compile
run:
SQLX_OFFLINE=true cargo pgrx run
SQLX_OFFLINE=true DATABASE_URL=${DATABASE_URL} cargo pgrx run pg15 postgres
2 changes: 1 addition & 1 deletion Trunk.toml
@@ -6,7 +6,7 @@ description = "The simplest implementation of LLM-backed vector search on Postgr
homepage = "https://github.com/tembo-io/pg_vectorize"
documentation = "https://github.com/tembo-io/pg_vectorize"
categories = ["orchestration", "machine_learning"]
version = "0.0.2"
version = "0.1.0"

[build]
postgres_version = "15"
104 changes: 63 additions & 41 deletions src/api.rs
@@ -4,6 +4,7 @@ use crate::openai::get_embeddings;
use crate::search::cosine_similarity_search;
use crate::types;
use crate::util;
use anyhow::Result;
use pgrx::prelude::*;

#[pg_extern]
@@ -19,9 +20,9 @@ fn table(
search_alg: default!(types::SimilarityAlg, "'pgv_cosine_similarity'"),
table_method: default!(init::TableMethod, "'append'"),
schedule: default!(String, "'* * * * *'"),
) -> String {
) -> Result<String> {
// initialize pgmq
init::init_pgmq().expect("error initializing pgmq");
init::init_pgmq()?;
let job_type = types::JobType::Columns;

// write job to table
@@ -30,7 +31,12 @@ fn table(
Some(a) => a,
None => format!("{}_{}_{}", schema, table, columns.join("_")),
};
let arguments = serde_json::to_value(args).expect("invalid json for argument `args`");
let arguments = match serde_json::to_value(args) {
Ok(a) => a,
Err(e) => {
error!("invalid json for argument `args`: {}", e);
}
};
let api_key = arguments.get("api_key");

// get prim key type
@@ -50,31 +56,36 @@ fn table(
// using SPI here because it is unlikely that this code will be run anywhere but inside the extension
// background worker will likely be moved to an external container or service in near future
let ran: Result<_, spi::Error> = Spi::connect(|mut c| {
let _ = c
.update(
&init_job_q,
None,
Some(vec![
(PgBuiltInOids::TEXTOID.oid(), job_name.clone().into_datum()),
(
PgBuiltInOids::TEXTOID.oid(),
job_type.to_string().into_datum(),
),
(
PgBuiltInOids::TEXTOID.oid(),
transformer.to_string().into_datum(),
),
(
PgBuiltInOids::TEXTOID.oid(),
search_alg.to_string().into_datum(),
),
(PgBuiltInOids::JSONBOID.oid(), params.into_datum()),
]),
)
.expect("error exec query");
match c.update(
&init_job_q,
None,
Some(vec![
(PgBuiltInOids::TEXTOID.oid(), job_name.clone().into_datum()),
(
PgBuiltInOids::TEXTOID.oid(),
job_type.to_string().into_datum(),
),
(
PgBuiltInOids::TEXTOID.oid(),
transformer.to_string().into_datum(),
),
(
PgBuiltInOids::TEXTOID.oid(),
search_alg.to_string().into_datum(),
),
(PgBuiltInOids::JSONBOID.oid(), params.into_datum()),
]),
) {
Ok(_) => (),
Err(e) => {
error!("error creating job: {}", e);
}
}
Ok(())
});
ran.expect("error creating job");
if ran.is_err() {
error!("error creating job");
}
let init_embed_q = init::init_embedding_table_query(
&job_name,
&schema,
@@ -85,14 +96,18 @@ fn table(
);

let ran: Result<_, spi::Error> = Spi::connect(|mut c| {
let _ = c.update(&init_embed_q, None, None);
let _r = c.update(&init_embed_q, None, None)?;
Ok(())
});
ran.expect("error creating embedding table");
if let Err(e) = ran {
error!("error creating embedding table: {}", e);
}
// TODO: first batch update
// then cron
// initialize cron
let _ = init::init_cron(&schedule, &job_name); // handle this error
format!("{schema}.{table}.{columns:?}.{transformer}.{search_alg}")
Ok(format!(
"{schema}.{table}.{columns:?}.{transformer}.{search_alg}"
))
}
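
The shape introduced for `table()` above recurs throughout this file: the `#[pg_extern]` function returns `anyhow::Result<T>` so recoverable failures propagate with `?`, while branches that cannot sensibly continue call pgrx's `error!`, which raises a Postgres ERROR and never returns. A rough standalone sketch of that shape, using a hypothetical function rather than anything from this commit:

    use anyhow::Result;
    use pgrx::prelude::*;

    #[pg_extern]
    fn example_init(job_name: &str) -> Result<String> {
        // Recoverable: `?` converts serde_json::Error into anyhow::Error and
        // surfaces it to the caller as a Postgres error.
        let params = serde_json::to_value(job_name)?;

        // Unrecoverable: error!() aborts the current transaction and diverges,
        // so this branch never needs to produce a value.
        if params.is_null() {
            error!("unexpected null params for job `{}`", job_name);
        }

        Ok(format!("initialized {job_name}"))
    }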

#[pg_extern]
@@ -107,28 +122,35 @@ fn search(
// this requires a query to metadata table to get the projects schema and table, which has a cost
// this does ensure consistency between the model used to generate the stored embeddings and the query embeddings, which is crucial

// TODO: simplify api signature as much as possible
// get project metadata
let _project_meta =
util::get_vectorize_meta_spi(job_name).expect("metadata for project is missing");
let project_meta: ColumnJobParams = serde_json::from_value(
serde_json::to_value(_project_meta).expect("failed to deserialize metadata"),
)
.expect("failed to deserialize metadata");
let _project_meta = if let Some(js) = util::get_vectorize_meta_spi(job_name) {
js
} else {
error!("failed to get project metadata");
};
let project_meta: ColumnJobParams =
serde_json::from_value(serde_json::to_value(_project_meta).unwrap_or_else(|e| {
error!("failed to serialize metadata: {}", e);
}))
.unwrap_or_else(|e| error!("failed to serialize metadata: {}", e));
// assuming default openai API for now
// get embeddings
let runtime = tokio::runtime::Builder::new_current_thread()
.enable_io()
.enable_time()
.build()
.unwrap();
.unwrap_or_else(|e| error!("failed to initialize tokio runtime: {}", e));

let schema = project_meta.schema;
let table = project_meta.table;

let embeddings = runtime
.block_on(async { get_embeddings(&vec![query.to_string()], api_key).await })
.expect("failed getting embeddings");
let embeddings =
match runtime.block_on(async { get_embeddings(&vec![query.to_string()], api_key).await }) {
Ok(e) => e,
Err(e) => {
error!("error getting embeddings: {}", e);
}
};
let search_results = cosine_similarity_search(
job_name,
&schema,
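
Several call sites above swap `.expect()` for `.unwrap_or_else(|e| error!(...))`. The closure type-checks even though it never produces a value because pgrx's `error!` diverges (its type is `!`), which coerces to whatever success type the `Result` carries, so the original error text still reaches Postgres. A standalone sketch of the same idea, with `panic!` standing in for `error!` so it runs outside the extension:

    fn parse_port(raw: &str) -> u16 {
        // panic! diverges, so the closure satisfies FnOnce(ParseIntError) -> u16
        // without ever constructing a u16 -- the same trick error! relies on in pgrx.
        raw.parse::<u16>()
            .unwrap_or_else(|e| panic!("invalid port `{raw}`: {e}"))
    }

    fn main() {
        assert_eq!(parse_port("5432"), 5432);
    }
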
63 changes: 7 additions & 56 deletions src/executor.rs
@@ -1,5 +1,4 @@
use pgrx::prelude::*;
use pgrx::spi::SpiTupleTable;

use crate::errors::DatabaseError;
use crate::init::{TableMethod, PGMQ_QUEUE_NAME};
@@ -86,25 +85,25 @@ fn job_execute(job_name: String) {
.enable_io()
.enable_time()
.build()
.unwrap();
.unwrap_or_else(|e| error!("failed to initialize tokio runtime: {}", e));

runtime.block_on(async {
let conn = get_pg_conn().await.expect("failed to connect to database");
let conn = get_pg_conn().await.unwrap_or_else(|e| error!("pg-vectorize: failed to establish db connection: {}", e));
let queue = pgmq::PGMQueueExt::new_with_pool(conn.clone())
.await
.expect("failed to init db connection");
.unwrap_or_else(|e| error!("failed to init db connection: {}", e));
let meta = get_vectorize_meta(&job_name, &conn)
.await
.expect("failed to get job meta");
.unwrap_or_else(|e| error!("failed to get job metadata: {}", e));
let job_params = serde_json::from_value::<ColumnJobParams>(meta.params.clone())
.expect("failed to deserialize job params");
.unwrap_or_else(|e| error!("failed to deserialize job params: {}", e));
let _last_completion = match meta.last_completion {
Some(t) => t,
None => Utc.with_ymd_and_hms(970, 1, 1, 0, 0, 0).unwrap(),
};
let new_or_updated_rows = get_new_updates_append(&conn, &job_name, job_params)
.await
.expect("failed to get new updates");
.unwrap_or_else(|e| error!("failed to get new updates: {}", e));
match new_or_updated_rows {
Some(rows) => {
log!("num new records: {}", rows.len());
@@ -116,7 +115,7 @@ fn job_execute(job_name: String) {
let msg_id = queue
.send(PGMQ_QUEUE_NAME, &msg)
.await
.expect("failed to send message");
.unwrap_or_else(|e| error!("failed to send message updates: {}", e));
log!("message sent: {}", msg_id);
}
None => {
@@ -250,54 +249,6 @@ pub async fn get_new_updates_shared(
}
}

// gets last processed times
fn get_inputs_query(
job_name: &str,
schema: &str,
table: &str,
columns: Vec<String>,
last_updated_col: &str,
) -> String {
let cols = collapse_to_csv(&columns);

format!(
"
SELECT {cols} as input_text
FROM {schema}.{table}
WHERE {last_updated_col} >
(
SELECT last_completion
FROM vectorize_meta
WHERE name = '{job_name}'
)::timestamp
"
)
}

// retrieves inputs for embedding model
#[pg_extern]
fn get_inputs(
job_name: &str,
schema: &str,
table: &str,
columns: Vec<String>,
updated_at_col: &str,
) -> Vec<String> {
let mut results: Vec<String> = Vec::new();
let query = get_inputs_query(job_name, schema, table, columns, updated_at_col);
let _: Result<(), pgrx::spi::Error> = Spi::connect(|mut client: spi::SpiClient<'_>| {
let tup_table: SpiTupleTable = client.update(&query, None, None)?;
for row in tup_table {
let input = row["input_text"]
.value::<String>()?
.expect("input column missing");
results.push(input);
}
Ok(())
});
results
}

fn collapse_to_csv(strings: &[String]) -> String {
strings
.iter()
18 changes: 12 additions & 6 deletions src/init.rs
@@ -2,6 +2,8 @@ use crate::{query::check_input, types};
use pgrx::prelude::*;
use serde::{Deserialize, Serialize};

use anyhow::Result;

pub const PGMQ_QUEUE_NAME: &str = "vectorize_queue";

#[derive(Clone, Debug, Serialize, Deserialize, PostgresEnum)]
@@ -12,15 +14,19 @@ pub enum TableMethod {
join,
}

pub fn init_pgmq() -> Result<(), spi::Error> {
Spi::connect(|mut c| {
let _ = c.update(
&format!("SELECT pgmq_create('{PGMQ_QUEUE_NAME}');"),
pub fn init_pgmq() -> Result<()> {
let ran: Result<_, spi::Error> = Spi::connect(|mut c| {
let _r = c.update(
&format!("SELECT pgmq.create('{PGMQ_QUEUE_NAME}');"),
None,
None,
);
)?;
Ok(())
})
});
if let Err(e) = ran {
error!("error creating embedding table: {}", e);
}
Ok(())
}

pub fn init_cron(cron: &str, job_name: &str) -> Result<Option<i64>, spi::Error> {
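
`init_pgmq` now returns `anyhow::Result<()>` (so `table()` can propagate it with `?`), while the inner `Spi::connect` closure keeps its concrete `Result<_, spi::Error>`. The two layers compose because any error implementing `std::error::Error` converts into `anyhow::Error` through `?`. A standalone sketch of that layering, with `ParseIntError` standing in for `spi::Error`:

    use anyhow::Result;
    use std::num::ParseIntError;

    // Inner layer: a concrete error type, like the Result<_, spi::Error> inside Spi::connect.
    fn inner(raw: &str) -> std::result::Result<i64, ParseIntError> {
        raw.parse::<i64>()
    }

    // Outer layer: anyhow::Result, like the new init_pgmq signature.
    fn outer(raw: &str) -> Result<i64> {
        let n = inner(raw)?; // ParseIntError converts into anyhow::Error here
        Ok(n * 2)
    }

    fn main() -> Result<()> {
        println!("{}", outer("21")?);
        Ok(())
    }
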
6 changes: 3 additions & 3 deletions src/openai.rs
@@ -31,8 +31,7 @@ pub async fn get_embeddings(inputs: &Vec<String>, key: &str) -> Result<Vec<Vec<f
.header("Content-Type", "application/json")
.header("Authorization", format!("Bearer {}", key))
.send()
.await
.expect("failed calling openai");
.await?;
let embedding_resp = handle_response::<EmbeddingResponse>(resp, "embeddings").await?;

let embeddings = embedding_resp
@@ -55,7 +54,8 @@ pub async fn handle_response<T: for<'de> serde::Deserialize<'de>>(
resp.status(),
resp.text().await?
);
error!("{}", errmsg);
warning!("pg-vectorize: error handling response: {}", errmsg);
return Err(anyhow::anyhow!(errmsg));
}
let value = resp.json::<T>().await?;
Ok(value)
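
`handle_response` is the one spot where a hard `error!` becomes a recoverable failure: a non-success HTTP status is logged with `warning!` and returned as `Err(anyhow::anyhow!(...))`, leaving the caller (`get_embeddings`, and ultimately the background worker) to decide how to react. A rough standalone sketch of the same check-status-then-deserialize shape with plain reqwest outside the extension (the helper name and URL are illustrative only):

    use anyhow::Result;
    use serde::de::DeserializeOwned;

    // Mirrors the shape of handle_response: reject non-2xx before deserializing.
    async fn parse_json_or_err<T: DeserializeOwned>(resp: reqwest::Response, ctx: &str) -> Result<T> {
        if !resp.status().is_success() {
            let msg = format!("error calling {ctx}: {} -- {}", resp.status(), resp.text().await?);
            eprintln!("{msg}"); // warning!() inside the extension; a plain log here
            return Err(anyhow::anyhow!(msg));
        }
        Ok(resp.json::<T>().await?)
    }

    // Requires tokio features "macros" and "rt-multi-thread", plus reqwest with "json".
    #[tokio::main]
    async fn main() -> Result<()> {
        let resp = reqwest::get("https://httpbin.org/json").await?;
        let value: serde_json::Value = parse_json_or_err(resp, "httpbin").await?;
        println!("{value}");
        Ok(())
    }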