diff --git a/Cargo.toml b/Cargo.toml
index d0132d2..8ec33b3 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -1,6 +1,6 @@
 [package]
 name = "vectorize"
-version = "0.6.1"
+version = "0.7.0"
 edition = "2021"
 publish = false
 
diff --git a/sql/vectorize--0.7.0--0.7.1.sql b/sql/vectorize--0.7.0--0.7.1.sql
new file mode 100644
index 0000000..13b6675
--- /dev/null
+++ b/sql/vectorize--0.7.0--0.7.1.sql
@@ -0,0 +1,49 @@
+-- The Rust enum variant `openai` is renamed to `text_embedding_ada_002`. Rename the
+-- matching label of the Postgres enum first, otherwise the DEFAULTs below do not resolve.
+-- Guarded so that the script is a no-op when the label has already been renamed.
+DO $$
+BEGIN
+    IF EXISTS (
+        SELECT 1
+        FROM pg_catalog.pg_enum e
+        JOIN pg_catalog.pg_type t ON t.oid = e.enumtypid
+        JOIN pg_catalog.pg_namespace n ON n.oid = t.typnamespace
+        WHERE n.nspname = 'vectorize'
+          AND t.typname = 'transformer'
+          AND e.enumlabel = 'openai'
+    ) THEN
+        ALTER TYPE vectorize.Transformer RENAME VALUE 'openai' TO 'text_embedding_ada_002';
+    END IF;
+END
+$$;
+
+DROP function vectorize."table";
+
+-- src/api.rs:14
+-- vectorize::api::table
+CREATE FUNCTION vectorize."table"(
+    "table" TEXT, /* &str */
+    "columns" TEXT[], /* alloc::vec::Vec<alloc::string::String> */
+    "job_name" TEXT, /* alloc::string::String */
+    "primary_key" TEXT, /* alloc::string::String */
+    "args" json DEFAULT '{}', /* pgrx::datum::json::Json */
+    "schema" TEXT DEFAULT 'public', /* alloc::string::String */
+    "update_col" TEXT DEFAULT 'last_updated_at', /* alloc::string::String */
+    "transformer" vectorize.Transformer DEFAULT 'text_embedding_ada_002', /* vectorize::types::Transformer */
+    "search_alg" vectorize.SimilarityAlg DEFAULT 'pgv_cosine_similarity', /* vectorize::types::SimilarityAlg */
+    "table_method" vectorize.TableMethod DEFAULT 'append', /* vectorize::types::TableMethod */
+    "schedule" TEXT DEFAULT '* * * * *' /* alloc::string::String */
+) RETURNS TEXT /* core::result::Result */
+STRICT
+LANGUAGE c /* Rust */
+AS 'MODULE_PATHNAME', 'table_wrapper';
+
+-- src/api.rs:172
+-- vectorize::api::transform_embeddings
+CREATE FUNCTION vectorize."transform_embeddings"(
+    "input" TEXT, /* &str */
+    "model_name" vectorize.Transformer DEFAULT 'text_embedding_ada_002', /* vectorize::types::Transformer */
+    "api_key" TEXT DEFAULT NULL /* core::option::Option<alloc::string::String> */
+) RETURNS double precision[] /* core::result::Result<alloc::vec::Vec<f64>, pgrx::spi::SpiError> */
+LANGUAGE c /* Rust */
+AS 'MODULE_PATHNAME', 'transform_embeddings_wrapper';
diff --git a/src/api.rs b/src/api.rs
index c1b7d29..73de480 100644
--- a/src/api.rs
+++ b/src/api.rs
@@ -1,12 +1,8 @@
 use crate::executor::VectorizeMeta;
 use crate::guc;
-use crate::guc::get_guc;
 use crate::init;
 use crate::search::cosine_similarity_search;
-use crate::transformers::{
-    http_handler::openai_embedding_request, openai, openai::OPENAI_EMBEDDING_MODEL,
-    openai::OPENAI_EMBEDDING_URL, types::EmbeddingPayload, types::EmbeddingRequest,
-};
+use crate::transformers::{openai, transform};
 use crate::types;
 use crate::types::JobParams;
 use crate::util;
@@ -23,7 +19,7 @@ fn table(
     args: default!(pgrx::Json, "'{}'"),
     schema: default!(String, "'public'"),
     update_col: default!(String, "'last_updated_at'"),
-    transformer: default!(types::Transformer, "'openai'"),
+    transformer: default!(types::Transformer, "'text_embedding_ada_002'"),
     search_alg: default!(types::SimilarityAlg, "'pgv_cosine_similarity'"),
     table_method: default!(types::TableMethod, "'append'"),
     schedule: default!(String, "'* * * * *'"),
@@ -47,7 +43,7 @@ fn table(
     // key can be set in a GUC, so if its required but not provided in args, and not in GUC, error
     init::init_pgmq(&transformer)?;
     match transformer {
-        types::Transformer::openai => {
+        types::Transformer::text_embedding_ada_002 => {
             let openai_key = match api_key {
                 Some(k) => serde_json::from_value::(k.clone())?,
                 None => match guc::get_guc(guc::VectorizeGuc::OpenAIKey) {
@@ -153,60 +149,10 @@ fn search(
     )
     .unwrap_or_else(|e| error!("failed to deserialize metadata: {}", e));
 
-    // get embeddings
-    let runtime = tokio::runtime::Builder::new_current_thread()
-        .enable_io()
-        .enable_time()
-        .build()
-        .unwrap_or_else(|e| error!("failed to initialize tokio runtime: {}", e));
-
     let schema = proj_params.schema;
     let table = proj_params.table;
 
-    let embedding_request = match project_meta.transformer {
-        types::Transformer::openai => {
-            let openai_key = match api_key {
-                Some(k) => k,
-                None => match guc::get_guc(guc::VectorizeGuc::OpenAIKey) {
-                    Some(k) => k,
-                    None => {
-                        error!("failed to get API key from GUC");
-                    }
-                },
-            };
-
-            let embedding_request = EmbeddingPayload {
-                input: vec![query.to_string()],
-                model: OPENAI_EMBEDDING_MODEL.to_string(),
-            };
-            EmbeddingRequest {
-                url: OPENAI_EMBEDDING_URL.to_owned(),
-                payload: embedding_request,
-                api_key: Some(openai_key),
-            }
-        }
-        types::Transformer::all_MiniLM_L12_v2 => {
-            let url: String = get_guc(guc::VectorizeGuc::EmbeddingServiceUrl)
-                .expect("failed to get embedding service url from GUC");
-            let embedding_request = EmbeddingPayload {
-                input: vec![query.to_string()],
-                model: project_meta.transformer.to_string(),
-            };
-            EmbeddingRequest {
-                url,
-                payload: embedding_request,
-                api_key: None,
-            }
-        }
-    };
-
-    let embeddings =
-        match runtime.block_on(async { openai_embedding_request(embedding_request).await }) {
-            Ok(e) => e,
-            Err(e) => {
-                error!("error getting embeddings: {}", e);
-            }
-        };
+    let embeddings = transform(query, project_meta.transformer, api_key);
 
     let search_results = match project_meta.search_alg {
         types::SimilarityAlg::pgv_cosine_similarity => cosine_similarity_search(
@@ -221,3 +167,22 @@ fn search(
 
     Ok(TableIterator::new(search_results))
 }
+
+/// Returns the embedding that `model_name` produces for `input`.
+///
+/// `api_key` is forwarded to `transform` unchanged; with the SQL default (`NULL`) it is `None`.
+/// Calls `error!` when the transformer returns no embedding at all.
+#[pg_extern]
+fn transform_embeddings(
+    input: &str,
+    model_name: default!(types::Transformer, "'text_embedding_ada_002'"),
+    api_key: default!(Option<String>, "NULL"),
+) -> Result<Vec<f64>, spi::Error> {
+    // Exactly one input is sent, so only the first embedding is of interest. Taking it via
+    // the iterator avoids `Vec::remove(0)`, which panics with an index message when empty.
+    let embedding = transform(input, model_name, api_key)
+        .into_iter()
+        .next()
+        .unwrap_or_else(|| error!("transformer returned no embeddings"));
+    Ok(embedding)
+}
diff --git a/src/init.rs b/src/init.rs
index cb8ce0b..22d4325 100644
--- a/src/init.rs
+++ b/src/init.rs
@@ -10,7 +10,7 @@ lazy_static! {
     // maintain the mapping of transformer to queue name here
     pub static ref QUEUE_MAPPING: HashMap = {
         let mut m = HashMap::new();
-        m.insert(Transformer::openai, "v_openai");
+        m.insert(Transformer::text_embedding_ada_002, "v_openai");
         m.insert(Transformer::all_MiniLM_L12_v2, "v_all_MiniLM_L12_v2");
         m
     };
@@ -84,7 +84,10 @@ pub fn init_embedding_table_query(
         // so that they can be read here, not hard-coded here below
         // currently only supports the text-embedding-ada-002 embedding model - output dim 1536
         // https://platform.openai.com/docs/guides/embeddings/what-are-embeddings
-        (types::Transformer::openai, types::SimilarityAlg::pgv_cosine_similarity) => "vector(1536)",
+        (
+            types::Transformer::text_embedding_ada_002,
+            types::SimilarityAlg::pgv_cosine_similarity,
+        ) => "vector(1536)",
         (types::Transformer::all_MiniLM_L12_v2, types::SimilarityAlg::pgv_cosine_similarity) => {
             "vector(384)"
         }
diff --git a/src/transformers/mod.rs b/src/transformers/mod.rs
index 3653ce9..6b43cd5 100644
--- a/src/transformers/mod.rs
+++ b/src/transformers/mod.rs
@@ -3,3 +3,61 @@ pub mod http_handler;
 pub mod openai;
 pub mod tembo;
 pub mod types;
+
+use crate::guc;
+use crate::types::Transformer;
+use http_handler::openai_embedding_request;
+use openai::{OPENAI_EMBEDDING_MODEL, OPENAI_EMBEDDING_URL};
+use pgrx::prelude::*;
+use types::{EmbeddingPayload, EmbeddingRequest};
+
+/// Requests embeddings for `input` from the service that backs `transformer`.
+///
+/// * `Transformer::text_embedding_ada_002`: sent to `OPENAI_EMBEDDING_URL` with `api_key`,
+///   falling back to the `VectorizeGuc::OpenAIKey` GUC when `api_key` is `None`.
+/// * `Transformer::all_MiniLM_L12_v2`: sent, without a key, to the URL read from the
+///   `VectorizeGuc::EmbeddingServiceUrl` GUC.
+///
+/// Calls `error!` if the runtime cannot be built, no OpenAI key is found, or the request
+/// fails; panics if the embedding service URL GUC is unset.
+pub fn transform(input: &str, transformer: Transformer, api_key: Option<String>) -> Vec<Vec<f64>> {
+    let runtime = tokio::runtime::Builder::new_current_thread()
+        .enable_io()
+        .enable_time()
+        .build()
+        .unwrap_or_else(|e| error!("failed to initialize tokio runtime: {}", e));
+
+    let embedding_request = match transformer {
+        Transformer::text_embedding_ada_002 => {
+            // an explicitly passed key takes precedence over the GUC
+            let openai_key = api_key
+                .or_else(|| guc::get_guc(guc::VectorizeGuc::OpenAIKey))
+                .unwrap_or_else(|| error!("failed to get API key from GUC"));
+
+            EmbeddingRequest {
+                url: OPENAI_EMBEDDING_URL.to_owned(),
+                payload: EmbeddingPayload {
+                    input: vec![input.to_string()],
+                    model: OPENAI_EMBEDDING_MODEL.to_string(),
+                },
+                api_key: Some(openai_key),
+            }
+        }
+        Transformer::all_MiniLM_L12_v2 => {
+            let url: String = guc::get_guc(guc::VectorizeGuc::EmbeddingServiceUrl)
+                .expect("failed to get embedding service url from GUC");
+            EmbeddingRequest {
+                url,
+                payload: EmbeddingPayload {
+                    input: vec![input.to_string()],
+                    model: transformer.to_string(),
+                },
+                api_key: None,
+            }
+        }
+    };
+
+    runtime
+        .block_on(async { openai_embedding_request(embedding_request).await })
+        .unwrap_or_else(|e| error!("error getting embeddings: {}", e))
+}
diff --git a/src/transformers/openai.rs b/src/transformers/openai.rs
index 6e26db6..ad9c0a0 100644
--- a/src/transformers/openai.rs
+++ b/src/transformers/openai.rs
@@ -3,6 +3,7 @@ use pgrx::prelude::*;
 use anyhow::Result;
 
 use crate::{
+    executor::VectorizeMeta,
     guc::OPENAI_KEY,
     transformers::{
         http_handler::handle_response,
@@ -18,14 +19,16 @@ pub const OPENAI_EMBEDDING_URL: &str = "https://api.openai.com/v1/embeddings";
 pub const OPENAI_EMBEDDING_MODEL: &str = "text-embedding-ada-002";
 
 pub fn prepare_openai_request(
-    job_params: JobParams,
+    vect_meta: VectorizeMeta,
     inputs: &[Inputs],
 ) -> Result {
     let text_inputs = trim_inputs(inputs);
+    let job_params: JobParams = serde_json::from_value(vect_meta.params.clone())?;
     let payload = EmbeddingPayload {
         input: text_inputs,
         model: OPENAI_EMBEDDING_MODEL.to_owned(),
     };
+
     let apikey = match job_params.api_key {
         Some(k) => k,
         None => {
diff --git a/src/types.rs b/src/types.rs
index eb72d8c..d6c9002 100644
--- a/src/types.rs
+++ b/src/types.rs
@@ -8,7 +8,7 @@ pub const VECTORIZE_SCHEMA: &str = "vectorize";
 #[allow(non_camel_case_types)]
 #[derive(Clone, Debug, Serialize, Deserialize, Eq, Hash, PartialEq, PostgresEnum)]
 pub enum Transformer {
-    openai,
+    text_embedding_ada_002,
     all_MiniLM_L12_v2,
 }
 
@@ -17,7 +17,7 @@ impl FromStr for Transformer {
 
     fn from_str(s: &str) -> Result {
         match s {
-            "openai" => Ok(Transformer::openai),
+            "text_embedding_ada_002" => Ok(Transformer::text_embedding_ada_002),
             "all_MiniLM_L12_v2" => Ok(Transformer::all_MiniLM_L12_v2),
             _ => Err(format!("Invalid value: {}", s)),
         }
@@ -27,7 +27,7 @@ impl FromStr for Transformer {
 impl From<String> for Transformer {
     fn from(s: String) -> Self {
         match s.as_str() {
-            "openai" => Transformer::openai,
+            "text_embedding_ada_002" => Transformer::text_embedding_ada_002,
             "all_MiniLM_L12_v2" => Transformer::all_MiniLM_L12_v2,
             _ => panic!("Invalid value for Transformer: {}", s), // or handle this case differently
         }
@@ -37,7 +37,7 @@ impl From<String> for Transformer {
 impl Display for Transformer {
     fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), std::fmt::Error> {
         match self {
-            Transformer::openai => write!(f, "openai"),
+            Transformer::text_embedding_ada_002 => write!(f, "text_embedding_ada_002"),
             Transformer::all_MiniLM_L12_v2 => write!(f, "all_MiniLM_L12_v2"),
         }
     }
diff --git a/src/workers/mod.rs b/src/workers/mod.rs
index ea40f97..ee8f5e4 100644
--- a/src/workers/mod.rs
+++ b/src/workers/mod.rs
@@ -155,9 +155,9 @@ async fn execute_job(dbclient: Pool, msg: Message) -> Resu
     let job_params: types::JobParams = serde_json::from_value(job_meta.params.clone())?;
 
     let embedding_request = match job_meta.transformer {
-        types::Transformer::openai => {
+        types::Transformer::text_embedding_ada_002 => {
             log!("pg-vectorize: OpenAI transformer");
-            openai::prepare_openai_request(job_params.clone(), &msg.message.inputs)
+            openai::prepare_openai_request(job_meta.clone(), &msg.message.inputs)
         }
         types::Transformer::all_MiniLM_L12_v2 => {
             generic::prepare_generic_embedding_request(job_meta.clone(), &msg.message.inputs)
diff --git a/src/workers/pg_bgw.rs b/src/workers/pg_bgw.rs
index 8f6d17e..f45f3cf 100644
--- a/src/workers/pg_bgw.rs
+++ b/src/workers/pg_bgw.rs
@@ -40,9 +40,8 @@ pub extern "C" fn background_worker_main(_arg: pg_sys::Datum) {
 
     log!("Starting BG Workers {}", BackgroundWorker::get_name(),);
 
-    // bgw only supports the OpenAI transformer case
     let oai_q = QUEUE_MAPPING
-        .get(&Transformer::openai)
+        .get(&Transformer::text_embedding_ada_002)
         .expect("invalid transformer");
     let aux_q = QUEUE_MAPPING
         .get(&Transformer::all_MiniLM_L12_v2)