diff --git a/src/transformers/generic.rs b/src/transformers/generic.rs index 448f561..3c5e1fe 100644 --- a/src/transformers/generic.rs +++ b/src/transformers/generic.rs @@ -1,7 +1,6 @@ use anyhow::{Context, Result}; use lazy_static::lazy_static; use regex::Regex; -use sqlx::Value; use crate::{ executor::VectorizeMeta, @@ -11,9 +10,8 @@ use crate::{ use super::openai::trim_inputs; -// Define a static regex using lazy_static lazy_static! { - static ref REGEX: Regex = Regex::new(r"\{([^}]+)\}").expect("Invalid regex"); + static ref REGEX: Regex = Regex::new(r"\$\{([^}]+)\}").expect("Invalid regex"); } use std::collections::HashSet; use std::env; @@ -21,7 +19,7 @@ use std::env; // finds all placeholders in a string fn find_placeholders(var: &str) -> Option> { let placeholders: HashSet = REGEX - .captures_iter(&var) + .captures_iter(var) .filter_map(|cap| cap.get(1)) .map(|match_| match_.as_str().to_owned()) .collect(); @@ -37,27 +35,21 @@ pub fn interpolate(base_str: &str, env_vars: Vec) -> Result { let mut interpolated_str = base_str.to_string(); for p in env_vars.iter() { let env_val = env::var(p).context(format!("failed to get env var: {}", p))?; - interpolated_str = interpolated_str.replace(&format!("{{{}}}", p), &env_val); + interpolated_str = interpolated_str.replace(&format!("${{{}}}", p), &env_val); } Ok(interpolated_str) } + pub fn get_generic_svc_url() -> Result { if let Some(url) = guc::get_guc(guc::VectorizeGuc::EmbeddingServiceUrl) { if let Some(phs) = find_placeholders(&url) { - // lookup the env vars let interpolated = interpolate(&url, phs)?; - return Ok(interpolated); + Ok(interpolated) } else { - // return anyhow error - Err(anyhow::anyhow!( - "failed to get embedding service url from GUC" - )) + Ok(url) } } else { - // return anyhow error - Err(anyhow::anyhow!( - "failed to get embedding service url from GUC" - )) + Err(anyhow::anyhow!("vectorize.embedding_service_url not set")) } } @@ -71,8 +63,7 @@ pub fn prepare_generic_embedding_request( model: job_meta.transformer.to_string(), }; - let svc_host = guc::get_guc(guc::VectorizeGuc::EmbeddingServiceUrl) - .context("vectorize.embedding_Service_url is not set")?; + let svc_host = get_generic_svc_url().context("failed to get embedding service url from GUC")?; Ok(EmbeddingRequest { url: svc_host, @@ -81,45 +72,63 @@ pub fn prepare_generic_embedding_request( }) } - #[cfg(test)] mod tests { use super::*; #[test] fn test_find_placeholders() { - let base_str = "http://{TEST_ENV_0}/test/{TEST_ENV_1}"; + let base_str = "http://${TEST_ENV_0}/test/${TEST_ENV_1}"; let placeholders = find_placeholders(base_str).unwrap(); - assert_eq!(placeholders, vec!["TEST_ENV_0".to_string(), "TEST_ENV_1".to_string()]); + assert!(placeholders.contains(&"TEST_ENV_0".to_owned())); + assert!(placeholders.contains(&"TEST_ENV_1".to_owned())); // no placeholders let base_str = "http://TEST_ENV_0/test/TEST_ENV_1"; let placeholders = find_placeholders(base_str); assert!(placeholders.is_none()); - } #[test] fn test_interpolate() { env::set_var("TEST_ENV_0", "A"); env::set_var("TEST_ENV_1", "B"); - let base_str = "http://{TEST_ENV_0}/test/{TEST_ENV_1}"; - let interpolated = interpolate(base_str, vec!["TEST_ENV_0".to_string(), "TEST_ENV_1".to_string()]).unwrap(); + let base_str = "http://${TEST_ENV_0}/test/${TEST_ENV_1}"; + let interpolated = interpolate( + base_str, + vec!["TEST_ENV_0".to_string(), "TEST_ENV_1".to_string()], + ) + .unwrap(); assert_eq!(interpolated, "http://A/test/B"); // change order - let base_str = "http://{TEST_ENV_1}/test/{TEST_ENV_0}"; - let interpolated = interpolate(base_str, vec!["TEST_ENV_0".to_string(), "TEST_ENV_1".to_string()]).unwrap(); + let base_str = "http://${TEST_ENV_1}/test/${TEST_ENV_0}"; + let interpolated = interpolate( + base_str, + vec!["TEST_ENV_0".to_string(), "TEST_ENV_1".to_string()], + ) + .unwrap(); assert_eq!(interpolated, "http://B/test/A"); - + // repeated str - let base_str = "http://{TEST_ENV_0}/test/{TEST_ENV_1}/{TEST_ENV_0}"; - let interpolated = interpolate(base_str, vec!["TEST_ENV_0".to_string(), "TEST_ENV_1".to_string()]).unwrap(); + let base_str = "http://${TEST_ENV_0}/test/${TEST_ENV_1}/${TEST_ENV_0}"; + let interpolated = interpolate( + base_str, + vec!["TEST_ENV_0".to_string(), "TEST_ENV_1".to_string()], + ) + .unwrap(); assert_eq!(interpolated, "http://A/test/B/A"); // missing env var should err - let base_str = "http://{TEST_ENV_0}/test/{TEST_ENV_1}/{TEST_ENV_2}"; - let interpolated = interpolate(base_str, vec!["TEST_ENV_0".to_string(), "TEST_ENV_1".to_string(), "TEST_ENV_2".to_string()]); + let base_str = "http://${TEST_ENV_0}/test/${TEST_ENV_1}/${TEST_ENV_2}"; + let interpolated = interpolate( + base_str, + vec![ + "TEST_ENV_0".to_string(), + "TEST_ENV_1".to_string(), + "TEST_ENV_2".to_string(), + ], + ); assert!(interpolated.is_err()); } -} \ No newline at end of file +} diff --git a/src/transformers/mod.rs b/src/transformers/mod.rs index 6b43cd5..c2868c5 100644 --- a/src/transformers/mod.rs +++ b/src/transformers/mod.rs @@ -6,6 +6,7 @@ pub mod types; use crate::guc; use crate::types::Transformer; +use generic::get_generic_svc_url; use http_handler::openai_embedding_request; use openai::{OPENAI_EMBEDDING_MODEL, OPENAI_EMBEDDING_URL}; use pgrx::prelude::*; @@ -41,8 +42,7 @@ pub fn transform(input: &str, transformer: Transformer, api_key: Option) } } Transformer::all_MiniLM_L12_v2 => { - let url: String = guc::get_guc(guc::VectorizeGuc::EmbeddingServiceUrl) - .expect("failed to get embedding service url from GUC"); + let url = get_generic_svc_url().expect("failed to get embedding service url from GUC"); let embedding_request = EmbeddingPayload { input: vec![input.to_string()], model: transformer.to_string(),