route messages for addt'l transformer types (#25)
* move gucs

* separate pg bgw as worker

* fix imports

* add upgrade script

* remove dead code

* rename job table

* sqlx-cli 0.7.2

* rename job table

* redundant model

* estimate token count

* pull batch size from config

* bump toml
ChuckHend authored Nov 20, 2023
1 parent f83b983 commit 589bcca
Showing 18 changed files with 440 additions and 452 deletions.

4 changes: 3 additions & 1 deletion Cargo.toml
@@ -1,6 +1,6 @@
[package]
name = "vectorize"
version = "0.5.0"
version = "0.6.0"
edition = "2021"
publish = false

@@ -15,6 +15,7 @@ pg_test = []
[dependencies]
anyhow = "1.0.72"
chrono = {version = "0.4.26", features = ["serde"] }
lazy_static = "1.4.0"
log = "0.4.19"
pgmq = "0.24.0"
pgrx = "0.11.0"
@@ -29,6 +30,7 @@ sqlx = { version = "0.7.2", features = [
"chrono",
] }
thiserror = "1.0.44"
tiktoken-rs = "0.5.7"
tokio = {version = "1.29.1", features = ["rt-multi-thread"] }
url = "2.4.0"

2 changes: 1 addition & 1 deletion Trunk.toml
@@ -6,7 +6,7 @@ description = "The simplest way to orchestrate vector search on Postgres."
homepage = "https://github.com/tembo-io/pg_vectorize"
documentation = "https://github.com/tembo-io/pg_vectorize"
categories = ["orchestration", "machine_learning"]
version = "0.5.0"
version = "0.6.0"

[build]
postgres_version = "15"
4 changes: 1 addition & 3 deletions sql/meta.sql
@@ -1,4 +1,4 @@
CREATE TABLE vectorize_meta (
CREATE TABLE vectorize.job (
job_id bigserial,
name TEXT NOT NULL UNIQUE,
job_type TEXT NOT NULL,
@@ -7,5 +7,3 @@ CREATE TABLE vectorize_meta (
params jsonb NOT NULL,
last_completion TIMESTAMP WITH TIME ZONE
);

CREATE EXTENSION IF NOT EXISTS pgmq CASCADE;
21 changes: 21 additions & 0 deletions sql/vectorize--0.5.0--0.6.0.sql
@@ -0,0 +1,21 @@
DROP function vectorize."table";

-- vectorize::api::table
CREATE FUNCTION vectorize."table"(
"table" TEXT, /* &str */
"columns" TEXT[], /* alloc::vec::Vec<alloc::string::String> */
"job_name" TEXT, /* alloc::string::String */
"primary_key" TEXT, /* alloc::string::String */
"args" json DEFAULT '{}', /* pgrx::datum::json::Json */
"schema" TEXT DEFAULT 'public', /* alloc::string::String */
"update_col" TEXT DEFAULT 'last_updated_at', /* alloc::string::String */
"transformer" vectorize.Transformer DEFAULT 'openai', /* vectorize::types::Transformer */
"search_alg" vectorize.SimilarityAlg DEFAULT 'pgv_cosine_similarity', /* vectorize::types::SimilarityAlg */
"table_method" vectorize.TableMethod DEFAULT 'append', /* vectorize::init::TableMethod */
"schedule" TEXT DEFAULT '* * * * *' /* alloc::string::String */
) RETURNS TEXT /* core::result::Result<alloc::string::String, anyhow::Error> */
STRICT
LANGUAGE c /* Rust */
AS 'MODULE_PATHNAME', 'table_wrapper';

ALTER TABLE vectorize.vectorize_meta RENAME TO vectorize.job;
59 changes: 0 additions & 59 deletions sqlx-data.json

This file was deleted.

41 changes: 22 additions & 19 deletions src/api.rs
@@ -1,8 +1,9 @@
use crate::executor::ColumnJobParams;
use crate::guc;
use crate::init;
use crate::openai;
use crate::search::cosine_similarity_search;
use crate::transformers::openai;
use crate::types;
use crate::types::JobParams;
use crate::util;
use anyhow::Result;
use pgrx::prelude::*;
@@ -19,11 +20,9 @@ fn table(
update_col: default!(String, "'last_updated_at'"),
transformer: default!(types::Transformer, "'openai'"),
search_alg: default!(types::SimilarityAlg, "'pgv_cosine_similarity'"),
table_method: default!(init::TableMethod, "'append'"),
table_method: default!(types::TableMethod, "'append'"),
schedule: default!(String, "'* * * * *'"),
) -> Result<String> {
// initialize pgmq
init::init_pgmq()?;
let job_type = types::JobType::Columns;

// write job to table
@@ -41,11 +40,12 @@ fn table(

// certain embedding services require an API key, e.g. openAI
// key can be set in a GUC, so if its required but not provided in args, and not in GUC, error
init::init_pgmq(&transformer)?;
match transformer {
types::Transformer::openai => {
let openai_key = match api_key {
Some(k) => serde_json::from_value::<String>(k.clone())?,
None => match util::get_guc(util::VectorizeGuc::OpenAIKey) {
None => match guc::get_guc(guc::VectorizeGuc::OpenAIKey) {
Some(k) => k,
None => {
error!("failed to get API key from GUC");
@@ -54,19 +54,22 @@ fn table(
};
openai::validate_api_key(&openai_key)?;
}
// no-op
types::Transformer::allMiniLML12v2 => (),
}

// TODO: implement a struct for these params
let params = pgrx::JsonB(serde_json::json!({
"schema": schema,
"table": table,
"columns": columns,
"update_time_col": update_col,
"table_method": table_method,
"primary_key": primary_key,
"pkey_type": pkey_type,
"api_key": api_key
}));
let valid_params = types::JobParams {
schema: schema.clone(),
table: table.to_string(),
columns: columns.clone(),
update_time_col: update_col,
table_method: table_method.clone(),
primary_key,
pkey_type,
api_key: api_key
.map(|k| serde_json::from_value::<String>(k.clone()).expect("error parsing api key")),
};
let params = pgrx::JsonB(serde_json::to_value(valid_params).expect("error serializing params"));

// using SPI here because it is unlikely that this code will be run anywhere but inside the extension.
// background worker will likely be moved to an external container or service in near future
@@ -145,7 +148,7 @@ fn search(
} else {
error!("failed to get project metadata");
};
let project_meta: ColumnJobParams =
let project_meta: JobParams =
serde_json::from_value(serde_json::to_value(_project_meta).unwrap_or_else(|e| {
error!("failed to serialize metadata: {}", e);
}))
@@ -163,7 +166,7 @@ fn search(

let openai_key = match api_key {
Some(k) => k,
None => match util::get_guc(util::VectorizeGuc::OpenAIKey) {
None => match guc::get_guc(guc::VectorizeGuc::OpenAIKey) {
Some(k) => k,
None => {
error!("failed to get API key from GUC");
71 changes: 18 additions & 53 deletions src/executor.rs
@@ -1,7 +1,8 @@
use pgrx::prelude::*;

use crate::errors::DatabaseError;
use crate::init::{TableMethod, PGMQ_QUEUE_NAME};
use crate::guc::BATCH_SIZE;
use crate::init::QUEUE_MAPPING;
use crate::query::check_input;
use crate::types;
use crate::util::{from_env_default, get_pg_conn};
@@ -12,6 +13,7 @@ use sqlx::error::Error;
use sqlx::postgres::PgRow;
use sqlx::types::chrono::Utc;
use sqlx::{FromRow, PgPool, Pool, Postgres, Row};
use tiktoken_rs::cl100k_base;

// schema for every job
// also schema for the vectorize.vectorize_meta table
@@ -27,46 +29,6 @@ pub struct VectorizeMeta {
pub last_completion: Option<chrono::DateTime<Utc>>,
}

// temporary struct for deserializing from db
// not needed when sqlx 0.7.x
#[derive(Clone, Debug, Deserialize, FromRow, Serialize, PostgresType)]
pub struct _VectorizeMeta {
pub job_id: i64,
pub name: String,
pub job_type: String,
pub transformer: String,
pub search_alg: String,
pub params: serde_json::Value,
#[serde(deserialize_with = "from_tsopt")]
pub last_completion: Option<chrono::DateTime<Utc>>,
}

impl From<_VectorizeMeta> for VectorizeMeta {
fn from(val: _VectorizeMeta) -> Self {
VectorizeMeta {
job_id: val.job_id,
name: val.name,
job_type: types::JobType::from(val.job_type),
transformer: types::Transformer::from(val.transformer),
search_alg: types::SimilarityAlg::from(val.search_alg),
params: val.params,
last_completion: val.last_completion,
}
}
}

#[derive(Clone, Deserialize, Debug, Serialize)]
pub struct ColumnJobParams {
pub schema: String,
pub table: String,
pub columns: Vec<String>,
pub primary_key: String,
pub pkey_type: String,
pub update_time_col: String,
pub api_key: Option<String>,
pub table_method: TableMethod,
}

// creates batches based on total token count
// batch_size is the max token count per batch
fn create_batches(data: Vec<Inputs>, batch_size: i32) -> Vec<Vec<Inputs>> {
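The body of create_batches is collapsed in this view. Below is a minimal sketch of greedy batching by summed token count, assuming Inputs carries the token_estimate field populated further down; it illustrates the idea rather than the extension's exact implementation.

// Stand-in mirroring the fields used below; the real struct lives in src/executor.rs.
pub struct Inputs {
    pub record_id: String,
    pub inputs: String,
    pub token_estimate: i32,
}

// Sketch: greedily pack inputs into batches whose summed token estimates
// stay at or under batch_size; a single oversized input still gets its own
// batch so no row is dropped.
fn create_batches(data: Vec<Inputs>, batch_size: i32) -> Vec<Vec<Inputs>> {
    let mut batches: Vec<Vec<Inputs>> = Vec::new();
    let mut current: Vec<Inputs> = Vec::new();
    let mut current_tokens = 0;
    for input in data {
        if !current.is_empty() && current_tokens + input.token_estimate > batch_size {
            batches.push(std::mem::take(&mut current));
            current_tokens = 0;
        }
        current_tokens += input.token_estimate;
        current.push(input);
    }
    if !current.is_empty() {
        batches.push(current);
    }
    batches
}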
@@ -112,9 +74,7 @@ fn job_execute(job_name: String) {
.build()
.unwrap_or_else(|e| error!("failed to initialize tokio runtime: {}", e));

// TODO: move into a config
// 100k tokens per batch
let max_batch_size = 100000;
let max_batch_size = BATCH_SIZE.get();

runtime.block_on(async {
let conn = get_pg_conn()
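BATCH_SIZE here is a pgrx GUC, per the "pull batch size from config" item in the commit message. A rough sketch of how such a setting could be declared and registered is shown below; the GUC name, bounds, and default (matching the old hardcoded 100k tokens) are assumptions, and the define_int_guc signature shown follows pgrx 0.11 (older pgrx versions omit the trailing GucFlags argument).

use pgrx::guc::{GucContext, GucFlags, GucRegistry, GucSetting};

// Default mirrors the previous hardcoded 100k-tokens-per-batch limit.
pub static BATCH_SIZE: GucSetting<i32> = GucSetting::<i32>::new(100_000);

pub fn init_guc() {
    // Hypothetical registration: exposes the limit as vectorize.batch_size
    // so it can be tuned at runtime instead of being compiled in.
    GucRegistry::define_int_guc(
        "vectorize.batch_size",
        "max token count per embedding batch",
        "max token count per batch of inputs sent to the transformer queue",
        &BATCH_SIZE,
        1,
        i32::MAX,
        GucContext::Suset,
        GucFlags::default(),
    );
}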
@@ -126,7 +86,7 @@ fn job_execute(job_name: String) {
let meta = get_vectorize_meta(&job_name, &conn)
.await
.unwrap_or_else(|e| error!("failed to get job metadata: {}", e));
let job_params = serde_json::from_value::<ColumnJobParams>(meta.params.clone())
let job_params = serde_json::from_value::<types::JobParams>(meta.params.clone())
.unwrap_or_else(|e| error!("failed to deserialize job params: {}", e));
let _last_completion = match meta.last_completion {
Some(t) => t,
@@ -151,8 +111,11 @@ fn job_execute(job_name: String) {
job_meta: meta.clone(),
inputs: b,
};
let queue_name = QUEUE_MAPPING
.get(&meta.transformer)
.expect("invalid transformer");
let msg_id = queue
.send(PGMQ_QUEUE_NAME, &msg)
.send(queue_name, &msg)
.await
.unwrap_or_else(|e| error!("failed to send message updates: {}", e));
log!("message sent: {}", msg_id);
@@ -172,17 +135,17 @@ pub async fn get_vectorize_meta(
) -> Result<VectorizeMeta, DatabaseError> {
log!("fetching job: {}", job_name);
let row = sqlx::query_as!(
_VectorizeMeta,
VectorizeMeta,
"
SELECT *
FROM vectorize.vectorize_meta
FROM vectorize.job
WHERE name = $1
",
job_name.to_string(),
)
.fetch_one(conn)
.await?;
Ok(row.into())
Ok(row)
}

#[derive(Clone, Debug, Deserialize, Serialize)]
@@ -197,7 +160,7 @@ pub struct Inputs {
pub async fn get_new_updates_append(
pool: &Pool<Postgres>,
job_name: &str,
job_params: ColumnJobParams,
job_params: types::JobParams,
) -> Result<Option<Vec<Inputs>>, DatabaseError> {
let cols = collapse_to_csv(&job_params.columns);

@@ -225,10 +188,11 @@ pub async fn get_new_updates_append(
match rows {
Ok(rows) => {
if !rows.is_empty() {
let bpe = cl100k_base().unwrap();
let mut new_inputs: Vec<Inputs> = Vec::new();
for r in rows {
let ipt: String = r.get("input_text");
let token_estimate = ipt.split_whitespace().count() as i32;
let token_estimate = bpe.encode_with_special_tokens(&ipt).len() as i32;
new_inputs.push(Inputs {
record_id: r.get("record_id"),
inputs: ipt,
@@ -249,7 +213,7 @@ pub async fn get_new_updates_append(
// queries a table and returns rows that need new embeddings
#[allow(dead_code)]
pub async fn get_new_updates_shared(
job_params: ColumnJobParams,
job_params: types::JobParams,
last_completion: chrono::DateTime<Utc>,
) -> Result<Option<Vec<Inputs>>, DatabaseError> {
let pool = PgPool::connect(&from_env_default(
@@ -280,9 +244,10 @@ pub async fn get_new_updates_shared(
let rows: Result<Vec<PgRow>, Error> = sqlx::query(&new_rows_query).fetch_all(&pool).await;
match rows {
Ok(rows) => {
let bpe = cl100k_base().unwrap();
for r in rows {
let ipt: String = r.get("input_text");
let token_estimate = ipt.split_whitespace().count() as i32;
let token_estimate = bpe.encode_with_special_tokens(&ipt).len() as i32;
new_inputs.push(Inputs {
record_id: r.get("record_id"),
inputs: ipt,
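For reference, the new token estimate uses the cl100k_base encoding from the tiktoken-rs crate added in Cargo.toml above, replacing the old whitespace word count. A small self-contained example of that call (the sample text is illustrative):

use tiktoken_rs::cl100k_base;

fn main() {
    // cl100k_base is the tokenizer used by OpenAI's embedding models, so its
    // count matches what the embedding API will actually see.
    let bpe = cl100k_base().expect("failed to load cl100k_base encoding");
    let input_text = "PostgreSQL can orchestrate vector search with pg_vectorize.";
    // encode_with_special_tokens includes special tokens in the count, mirroring
    // the token_estimate computed for each row above.
    let token_estimate = bpe.encode_with_special_tokens(input_text).len() as i32;
    println!("estimated tokens: {token_estimate}");
}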