Skip to content

Commit

Permalink
add transform to sql api
Browse files Browse the repository at this point in the history
  • Loading branch information
ChuckHend committed Dec 8, 2023
1 parent 31a7168 commit 9639986
Show file tree
Hide file tree
Showing 9 changed files with 118 additions and 70 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "vectorize"
version = "0.6.1"
version = "0.7.0"
edition = "2021"
publish = false

Expand Down
30 changes: 30 additions & 0 deletions sql/vectorize--0.7.0--0.7.1.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
DROP function vectorize."table";

-- src/api.rs:14
-- vectorize::api::table
CREATE FUNCTION vectorize."table"(
"table" TEXT, /* &str */
"columns" TEXT[], /* alloc::vec::Vec<alloc::string::String> */
"job_name" TEXT, /* alloc::string::String */
"primary_key" TEXT, /* alloc::string::String */
"args" json DEFAULT '{}', /* pgrx::datum::json::Json */
"schema" TEXT DEFAULT 'public', /* alloc::string::String */
"update_col" TEXT DEFAULT 'last_updated_at', /* alloc::string::String */
"transformer" vectorize.Transformer DEFAULT 'text_embedding_ada_002', /* vectorize::types::Transformer */
"search_alg" vectorize.SimilarityAlg DEFAULT 'pgv_cosine_similarity', /* vectorize::types::SimilarityAlg */
"table_method" vectorize.TableMethod DEFAULT 'append', /* vectorize::types::TableMethod */
"schedule" TEXT DEFAULT '* * * * *' /* alloc::string::String */
) RETURNS TEXT /* core::result::Result<alloc::string::String, anyhow::Error> */
STRICT
LANGUAGE c /* Rust */
AS 'MODULE_PATHNAME', 'table_wrapper';

-- src/api.rs:172
-- vectorize::api::transform_embeddings
CREATE FUNCTION vectorize."transform_embeddings"(
"input" TEXT, /* &str */
"model_name" vectorize.Transformer DEFAULT 'text_embedding_ada_002', /* vectorize::types::Transformer */
"api_key" TEXT DEFAULT NULL /* core::option::Option<alloc::string::String> */
) RETURNS double precision[] /* core::result::Result<alloc::vec::Vec<f64>, pgrx::spi::SpiError> */
LANGUAGE c /* Rust */
AS 'MODULE_PATHNAME', 'transform_embeddings_wrapper';
71 changes: 13 additions & 58 deletions src/api.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,8 @@
use crate::executor::VectorizeMeta;
use crate::guc;
use crate::guc::get_guc;
use crate::init;
use crate::search::cosine_similarity_search;
use crate::transformers::{
http_handler::openai_embedding_request, openai, openai::OPENAI_EMBEDDING_MODEL,
openai::OPENAI_EMBEDDING_URL, types::EmbeddingPayload, types::EmbeddingRequest,
};
use crate::transformers::{openai, transform};
use crate::types;
use crate::types::JobParams;
use crate::util;
Expand All @@ -23,7 +19,7 @@ fn table(
args: default!(pgrx::Json, "'{}'"),
schema: default!(String, "'public'"),
update_col: default!(String, "'last_updated_at'"),
transformer: default!(types::Transformer, "'openai'"),
transformer: default!(types::Transformer, "'text_embedding_ada_002'"),
search_alg: default!(types::SimilarityAlg, "'pgv_cosine_similarity'"),
table_method: default!(types::TableMethod, "'append'"),
schedule: default!(String, "'* * * * *'"),
Expand All @@ -47,7 +43,7 @@ fn table(
// key can be set in a GUC, so if its required but not provided in args, and not in GUC, error
init::init_pgmq(&transformer)?;
match transformer {
types::Transformer::openai => {
types::Transformer::text_embedding_ada_002 => {
let openai_key = match api_key {
Some(k) => serde_json::from_value::<String>(k.clone())?,
None => match guc::get_guc(guc::VectorizeGuc::OpenAIKey) {
Expand Down Expand Up @@ -153,60 +149,10 @@ fn search(
)
.unwrap_or_else(|e| error!("failed to deserialize metadata: {}", e));

// get embeddings
let runtime = tokio::runtime::Builder::new_current_thread()
.enable_io()
.enable_time()
.build()
.unwrap_or_else(|e| error!("failed to initialize tokio runtime: {}", e));

let schema = proj_params.schema;
let table = proj_params.table;

let embedding_request = match project_meta.transformer {
types::Transformer::openai => {
let openai_key = match api_key {
Some(k) => k,
None => match guc::get_guc(guc::VectorizeGuc::OpenAIKey) {
Some(k) => k,
None => {
error!("failed to get API key from GUC");
}
},
};

let embedding_request = EmbeddingPayload {
input: vec![query.to_string()],
model: OPENAI_EMBEDDING_MODEL.to_string(),
};
EmbeddingRequest {
url: OPENAI_EMBEDDING_URL.to_owned(),
payload: embedding_request,
api_key: Some(openai_key),
}
}
types::Transformer::all_MiniLM_L12_v2 => {
let url: String = get_guc(guc::VectorizeGuc::EmbeddingServiceUrl)
.expect("failed to get embedding service url from GUC");
let embedding_request = EmbeddingPayload {
input: vec![query.to_string()],
model: project_meta.transformer.to_string(),
};
EmbeddingRequest {
url,
payload: embedding_request,
api_key: None,
}
}
};

let embeddings =
match runtime.block_on(async { openai_embedding_request(embedding_request).await }) {
Ok(e) => e,
Err(e) => {
error!("error getting embeddings: {}", e);
}
};
let embeddings = transform(query, project_meta.transformer, api_key);

let search_results = match project_meta.search_alg {
types::SimilarityAlg::pgv_cosine_similarity => cosine_similarity_search(
Expand All @@ -221,3 +167,12 @@ fn search(

Ok(TableIterator::new(search_results))
}

#[pg_extern]
fn transform_embeddings(
input: &str,
model_name: default!(types::Transformer, "'text_embedding_ada_002'"),
api_key: default!(Option<String>, "NULL"),
) -> Result<Vec<f64>, spi::Error> {
Ok(transform(input, model_name, api_key).remove(0))
}
7 changes: 5 additions & 2 deletions src/init.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ lazy_static! {
// maintain the mapping of transformer to queue name here
pub static ref QUEUE_MAPPING: HashMap<Transformer, &'static str> = {
let mut m = HashMap::new();
m.insert(Transformer::openai, "v_openai");
m.insert(Transformer::text_embedding_ada_002, "v_openai");
m.insert(Transformer::all_MiniLM_L12_v2, "v_all_MiniLM_L12_v2");
m
};
Expand Down Expand Up @@ -84,7 +84,10 @@ pub fn init_embedding_table_query(
// so that they can be read here, not hard-coded here below
// currently only supports the text-embedding-ada-002 embedding model - output dim 1536
// https://platform.openai.com/docs/guides/embeddings/what-are-embeddings
(types::Transformer::openai, types::SimilarityAlg::pgv_cosine_similarity) => "vector(1536)",
(
types::Transformer::text_embedding_ada_002,
types::SimilarityAlg::pgv_cosine_similarity,
) => "vector(1536)",
(types::Transformer::all_MiniLM_L12_v2, types::SimilarityAlg::pgv_cosine_similarity) => {
"vector(384)"
}
Expand Down
58 changes: 58 additions & 0 deletions src/transformers/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,61 @@ pub mod http_handler;
pub mod openai;
pub mod tembo;
pub mod types;

use crate::guc;
use crate::types::Transformer;
use http_handler::openai_embedding_request;
use openai::{OPENAI_EMBEDDING_MODEL, OPENAI_EMBEDDING_URL};
use pgrx::prelude::*;
use types::{EmbeddingPayload, EmbeddingRequest};

pub fn transform(input: &str, transformer: Transformer, api_key: Option<String>) -> Vec<Vec<f64>> {
let runtime = tokio::runtime::Builder::new_current_thread()
.enable_io()
.enable_time()
.build()
.unwrap_or_else(|e| error!("failed to initialize tokio runtime: {}", e));

let embedding_request = match transformer {
Transformer::text_embedding_ada_002 => {
let openai_key = match api_key {
Some(k) => k,
None => match guc::get_guc(guc::VectorizeGuc::OpenAIKey) {
Some(k) => k,
None => {
error!("failed to get API key from GUC");
}
},
};

let embedding_request = EmbeddingPayload {
input: vec![input.to_string()],
model: OPENAI_EMBEDDING_MODEL.to_string(),
};
EmbeddingRequest {
url: OPENAI_EMBEDDING_URL.to_owned(),
payload: embedding_request,
api_key: Some(openai_key),
}
}
Transformer::all_MiniLM_L12_v2 => {
let url: String = guc::get_guc(guc::VectorizeGuc::EmbeddingServiceUrl)
.expect("failed to get embedding service url from GUC");
let embedding_request = EmbeddingPayload {
input: vec![input.to_string()],
model: transformer.to_string(),
};
EmbeddingRequest {
url,
payload: embedding_request,
api_key: None,
}
}
};
match runtime.block_on(async { openai_embedding_request(embedding_request).await }) {
Ok(e) => e,
Err(e) => {
error!("error getting embeddings: {}", e);
}
}
}
5 changes: 4 additions & 1 deletion src/transformers/openai.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use pgrx::prelude::*;
use anyhow::Result;

use crate::{
executor::VectorizeMeta,
guc::OPENAI_KEY,
transformers::{
http_handler::handle_response,
Expand All @@ -18,14 +19,16 @@ pub const OPENAI_EMBEDDING_URL: &str = "https://api.openai.com/v1/embeddings";
pub const OPENAI_EMBEDDING_MODEL: &str = "text-embedding-ada-002";

pub fn prepare_openai_request(
job_params: JobParams,
vect_meta: VectorizeMeta,
inputs: &[Inputs],
) -> Result<EmbeddingRequest> {
let text_inputs = trim_inputs(inputs);
let job_params: JobParams = serde_json::from_value(vect_meta.params.clone())?;
let payload = EmbeddingPayload {
input: text_inputs,
model: OPENAI_EMBEDDING_MODEL.to_owned(),
};

let apikey = match job_params.api_key {
Some(k) => k,
None => {
Expand Down
8 changes: 4 additions & 4 deletions src/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ pub const VECTORIZE_SCHEMA: &str = "vectorize";
#[allow(non_camel_case_types)]
#[derive(Clone, Debug, Serialize, Deserialize, Eq, Hash, PartialEq, PostgresEnum)]
pub enum Transformer {
openai,
text_embedding_ada_002,
all_MiniLM_L12_v2,
}

Expand All @@ -17,7 +17,7 @@ impl FromStr for Transformer {

fn from_str(s: &str) -> Result<Self, Self::Err> {
match s {
"openai" => Ok(Transformer::openai),
"text_embedding_ada_002" => Ok(Transformer::text_embedding_ada_002),
"all_MiniLM_L12_v2" => Ok(Transformer::all_MiniLM_L12_v2),
_ => Err(format!("Invalid value: {}", s)),
}
Expand All @@ -27,7 +27,7 @@ impl FromStr for Transformer {
impl From<String> for Transformer {
fn from(s: String) -> Self {
match s.as_str() {
"openai" => Transformer::openai,
"text_embedding_ada_002" => Transformer::text_embedding_ada_002,
"all_MiniLM_L12_v2" => Transformer::all_MiniLM_L12_v2,
_ => panic!("Invalid value for Transformer: {}", s), // or handle this case differently
}
Expand All @@ -37,7 +37,7 @@ impl From<String> for Transformer {
impl Display for Transformer {
fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), std::fmt::Error> {
match self {
Transformer::openai => write!(f, "openai"),
Transformer::text_embedding_ada_002 => write!(f, "text_embedding_ada_002"),
Transformer::all_MiniLM_L12_v2 => write!(f, "all_MiniLM_L12_v2"),
}
}
Expand Down
4 changes: 2 additions & 2 deletions src/workers/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -155,9 +155,9 @@ async fn execute_job(dbclient: Pool<Postgres>, msg: Message<JobMessage>) -> Resu
let job_params: types::JobParams = serde_json::from_value(job_meta.params.clone())?;

let embedding_request = match job_meta.transformer {
types::Transformer::openai => {
types::Transformer::text_embedding_ada_002 => {
log!("pg-vectorize: OpenAI transformer");
openai::prepare_openai_request(job_params.clone(), &msg.message.inputs)
openai::prepare_openai_request(job_meta.clone(), &msg.message.inputs)
}
types::Transformer::all_MiniLM_L12_v2 => {
generic::prepare_generic_embedding_request(job_meta.clone(), &msg.message.inputs)
Expand Down
3 changes: 1 addition & 2 deletions src/workers/pg_bgw.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,8 @@ pub extern "C" fn background_worker_main(_arg: pg_sys::Datum) {

log!("Starting BG Workers {}", BackgroundWorker::get_name(),);

// bgw only supports the OpenAI transformer case
let oai_q = QUEUE_MAPPING
.get(&Transformer::openai)
.get(&Transformer::text_embedding_ada_002)
.expect("invalid transformer");
let aux_q = QUEUE_MAPPING
.get(&Transformer::all_MiniLM_L12_v2)
Expand Down

0 comments on commit 9639986

Please sign in to comment.