Skip to content

Commit

Permalink
support variable expansion in svc url
Browse files (browse the repository at this point in the history)
  • Loading branch information
ChuckHend committed Dec 9, 2023
1 parent 23ae8ca commit 4a5023d
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 33 deletions.
71 changes: 40 additions & 31 deletions src/transformers/generic.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
use anyhow::{Context, Result};
use lazy_static::lazy_static;
use regex::Regex;
use sqlx::Value;

use crate::{
executor::VectorizeMeta,
Expand All @@ -11,17 +10,16 @@ use crate::{

use super::openai::trim_inputs;

// Define a static regex using lazy_static
lazy_static! {
static ref REGEX: Regex = Regex::new(r"\{([^}]+)\}").expect("Invalid regex");
static ref REGEX: Regex = Regex::new(r"\$\{([^}]+)\}").expect("Invalid regex");
}
use std::collections::HashSet;
use std::env;

// finds all placeholders in a string
fn find_placeholders(var: &str) -> Option<Vec<String>> {
let placeholders: HashSet<String> = REGEX
.captures_iter(&var)
.captures_iter(var)
.filter_map(|cap| cap.get(1))
.map(|match_| match_.as_str().to_owned())
.collect();
Expand All @@ -37,27 +35,21 @@ pub fn interpolate(base_str: &str, env_vars: Vec<String>) -> Result<String> {
let mut interpolated_str = base_str.to_string();
for p in env_vars.iter() {
let env_val = env::var(p).context(format!("failed to get env var: {}", p))?;
interpolated_str = interpolated_str.replace(&format!("{{{}}}", p), &env_val);
interpolated_str = interpolated_str.replace(&format!("${{{}}}", p), &env_val);
}
Ok(interpolated_str)
}

pub fn get_generic_svc_url() -> Result<String> {
if let Some(url) = guc::get_guc(guc::VectorizeGuc::EmbeddingServiceUrl) {
if let Some(phs) = find_placeholders(&url) {
// lookup the env vars
let interpolated = interpolate(&url, phs)?;
return Ok(interpolated);
Ok(interpolated)
} else {
// return anyhow error
Err(anyhow::anyhow!(
"failed to get embedding service url from GUC"
))
Ok(url)
}
} else {
// return anyhow error
Err(anyhow::anyhow!(
"failed to get embedding service url from GUC"
))
Err(anyhow::anyhow!("vectorize.embedding_service_url not set"))
}
}

Expand All @@ -71,8 +63,7 @@ pub fn prepare_generic_embedding_request(
model: job_meta.transformer.to_string(),
};

let svc_host = guc::get_guc(guc::VectorizeGuc::EmbeddingServiceUrl)
.context("vectorize.embedding_Service_url is not set")?;
let svc_host = get_generic_svc_url().context("failed to get embedding service url from GUC")?;

Ok(EmbeddingRequest {
url: svc_host,
Expand All @@ -81,45 +72,63 @@ pub fn prepare_generic_embedding_request(
})
}


#[cfg(test)]
mod tests {
use super::*;

#[test]
fn test_find_placeholders() {
let base_str = "http://{TEST_ENV_0}/test/{TEST_ENV_1}";
let base_str = "http://${TEST_ENV_0}/test/${TEST_ENV_1}";
let placeholders = find_placeholders(base_str).unwrap();
assert_eq!(placeholders, vec!["TEST_ENV_0".to_string(), "TEST_ENV_1".to_string()]);
assert!(placeholders.contains(&"TEST_ENV_0".to_owned()));
assert!(placeholders.contains(&"TEST_ENV_1".to_owned()));

// no placeholders
let base_str = "http://TEST_ENV_0/test/TEST_ENV_1";
let placeholders = find_placeholders(base_str);
assert!(placeholders.is_none());

}

#[test]
fn test_interpolate() {
env::set_var("TEST_ENV_0", "A");
env::set_var("TEST_ENV_1", "B");
let base_str = "http://{TEST_ENV_0}/test/{TEST_ENV_1}";
let interpolated = interpolate(base_str, vec!["TEST_ENV_0".to_string(), "TEST_ENV_1".to_string()]).unwrap();
let base_str = "http://${TEST_ENV_0}/test/${TEST_ENV_1}";
let interpolated = interpolate(
base_str,
vec!["TEST_ENV_0".to_string(), "TEST_ENV_1".to_string()],
)
.unwrap();
assert_eq!(interpolated, "http://A/test/B");

// change order
let base_str = "http://{TEST_ENV_1}/test/{TEST_ENV_0}";
let interpolated = interpolate(base_str, vec!["TEST_ENV_0".to_string(), "TEST_ENV_1".to_string()]).unwrap();
let base_str = "http://${TEST_ENV_1}/test/${TEST_ENV_0}";
let interpolated = interpolate(
base_str,
vec!["TEST_ENV_0".to_string(), "TEST_ENV_1".to_string()],
)
.unwrap();
assert_eq!(interpolated, "http://B/test/A");

// repeated str
let base_str = "http://{TEST_ENV_0}/test/{TEST_ENV_1}/{TEST_ENV_0}";
let interpolated = interpolate(base_str, vec!["TEST_ENV_0".to_string(), "TEST_ENV_1".to_string()]).unwrap();
let base_str = "http://${TEST_ENV_0}/test/${TEST_ENV_1}/${TEST_ENV_0}";
let interpolated = interpolate(
base_str,
vec!["TEST_ENV_0".to_string(), "TEST_ENV_1".to_string()],
)
.unwrap();
assert_eq!(interpolated, "http://A/test/B/A");

// missing env var should err
let base_str = "http://{TEST_ENV_0}/test/{TEST_ENV_1}/{TEST_ENV_2}";
let interpolated = interpolate(base_str, vec!["TEST_ENV_0".to_string(), "TEST_ENV_1".to_string(), "TEST_ENV_2".to_string()]);
let base_str = "http://${TEST_ENV_0}/test/${TEST_ENV_1}/${TEST_ENV_2}";
let interpolated = interpolate(
base_str,
vec![
"TEST_ENV_0".to_string(),
"TEST_ENV_1".to_string(),
"TEST_ENV_2".to_string(),
],
);
assert!(interpolated.is_err());
}
}
}
4 changes: 2 additions & 2 deletions src/transformers/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ pub mod types;

use crate::guc;
use crate::types::Transformer;
use generic::get_generic_svc_url;
use http_handler::openai_embedding_request;
use openai::{OPENAI_EMBEDDING_MODEL, OPENAI_EMBEDDING_URL};
use pgrx::prelude::*;
Expand Down Expand Up @@ -41,8 +42,7 @@ pub fn transform(input: &str, transformer: Transformer, api_key: Option<String>)
}
}
Transformer::all_MiniLM_L12_v2 => {
let url: String = guc::get_guc(guc::VectorizeGuc::EmbeddingServiceUrl)
.expect("failed to get embedding service url from GUC");
let url = get_generic_svc_url().expect("failed to get embedding service url from GUC");
let embedding_request = EmbeddingPayload {
input: vec![input.to_string()],
model: transformer.to_string(),
Expand Down

0 comments on commit 4a5023d

Please sign in to comment.