Skip to content

Commit

Permalink
variable expansion in svc url (#30)
Browse files Browse the repository at this point in the history
* add transform to sql api

* interpolate vars

* support variable expansion in svc url
  • Loading branch information
ChuckHend authored Dec 9, 2023
1 parent e425020 commit 2136cf2
Show file tree
Hide file tree
Showing 3 changed files with 111 additions and 5 deletions.
3 changes: 2 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,11 @@ CREATE EXTENSION vectorize CASCADE;
If you're installing in an existing Postgres instance, you will need the following dependencies:

Rust:

- [pgrx toolchain](https://github.com/pgcentralfoundation/pgrx)

Postgres Extensions:

- [pg_cron](https://github.com/citusdata/pg_cron) == 1.5
- [pgmq](https://github.com/tembo-io/pgmq) >= 0.30.0
- [pgvector](https://github.com/pgvector/pgvector) >= 1.5.0
Expand Down Expand Up @@ -66,7 +68,6 @@ Create a job to vectorize the products table. We'll specify the tables primary k
ALTER SYSTEM SET vectorize.openai_key TO '<your api key>';
```


```sql
SELECT vectorize.table(
job_name => 'product_search',
Expand Down
109 changes: 107 additions & 2 deletions src/transformers/generic.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
use anyhow::{Context, Result};
use lazy_static::lazy_static;
use regex::Regex;

use crate::{
executor::VectorizeMeta,
Expand All @@ -8,6 +10,49 @@ use crate::{

use super::openai::trim_inputs;

// Compiled once on first use. The pattern matches `${NAME}` placeholders;
// capture group 1 is the non-empty run of characters between the braces
// (anything except `}`).
lazy_static! {
static ref REGEX: Regex = Regex::new(r"\$\{([^}]+)\}").expect("Invalid regex");
}
use std::collections::HashSet;
use std::env;

// Collects the distinct `${NAME}` placeholder names found in `var`.
//
// Returns `None` when the string holds no placeholder at all; otherwise
// returns the unique names (without the `${` / `}` delimiters). The names
// pass through a `HashSet`, so their order in the returned vector is not
// defined.
fn find_placeholders(var: &str) -> Option<Vec<String>> {
    let mut unique_names: HashSet<String> = HashSet::new();
    for captures in REGEX.captures_iter(var) {
        // group 1 is the text between the braces
        if let Some(name) = captures.get(1) {
            unique_names.insert(name.as_str().to_owned());
        }
    }

    if unique_names.is_empty() {
        return None;
    }
    Some(unique_names.into_iter().collect())
}

// interpolates a string with given env vars
pub fn interpolate(base_str: &str, env_vars: Vec<String>) -> Result<String> {
let mut interpolated_str = base_str.to_string();
for p in env_vars.iter() {
let env_val = env::var(p).context(format!("failed to get env var: {}", p))?;
interpolated_str = interpolated_str.replace(&format!("${{{}}}", p), &env_val);
}
Ok(interpolated_str)
}

// Reads the embedding service URL from the `EmbeddingServiceUrl` GUC and
// expands any `${NAME}` placeholders in it from environment variables.
//
// # Errors
// Fails when the GUC is not set, or when a placeholder refers to an
// environment variable that cannot be read.
pub fn get_generic_svc_url() -> Result<String> {
    let raw_url = match guc::get_guc(guc::VectorizeGuc::EmbeddingServiceUrl) {
        Some(value) => value,
        None => return Err(anyhow::anyhow!("vectorize.embedding_service_url not set")),
    };

    match find_placeholders(&raw_url) {
        // at least one placeholder: substitute from the environment
        Some(names) => interpolate(&raw_url, names),
        // plain URL: hand it back untouched
        None => Ok(raw_url),
    }
}

pub fn prepare_generic_embedding_request(
job_meta: VectorizeMeta,
inputs: &[Inputs],
Expand All @@ -18,12 +63,72 @@ pub fn prepare_generic_embedding_request(
model: job_meta.transformer.to_string(),
};

let svc_host = guc::get_guc(guc::VectorizeGuc::EmbeddingServiceUrl)
.context("vectorize.embedding_Service_url is not set")?;
let svc_host = get_generic_svc_url().context("failed to get embedding service url from GUC")?;

Ok(EmbeddingRequest {
url: svc_host,
payload,
api_key: None,
})
}

#[cfg(test)]
mod tests {
    use super::*;

    // builds the owned list of variable names that `interpolate` takes
    fn owned(names: &[&str]) -> Vec<String> {
        names.iter().map(|n| n.to_string()).collect()
    }

    #[test]
    fn test_find_placeholders() {
        // both placeholders are reported; their order is not checked
        let found = find_placeholders("http://${TEST_ENV_0}/test/${TEST_ENV_1}").unwrap();
        assert!(found.iter().any(|p| p == "TEST_ENV_0"));
        assert!(found.iter().any(|p| p == "TEST_ENV_1"));

        // no placeholders
        assert!(find_placeholders("http://TEST_ENV_0/test/TEST_ENV_1").is_none());
    }

    #[test]
    fn test_interpolate() {
        env::set_var("TEST_ENV_0", "A");
        env::set_var("TEST_ENV_1", "B");

        // (template, expected result): in order, swapped order, repeated token
        let cases = [
            ("http://${TEST_ENV_0}/test/${TEST_ENV_1}", "http://A/test/B"),
            ("http://${TEST_ENV_1}/test/${TEST_ENV_0}", "http://B/test/A"),
            (
                "http://${TEST_ENV_0}/test/${TEST_ENV_1}/${TEST_ENV_0}",
                "http://A/test/B/A",
            ),
        ];
        for (template, expected) in cases.iter() {
            let rendered = interpolate(template, owned(&["TEST_ENV_0", "TEST_ENV_1"])).unwrap();
            assert_eq!(rendered, *expected);
        }

        // missing env var should err
        let outcome = interpolate(
            "http://${TEST_ENV_0}/test/${TEST_ENV_1}/${TEST_ENV_2}",
            owned(&["TEST_ENV_0", "TEST_ENV_1", "TEST_ENV_2"]),
        );
        assert!(outcome.is_err());
    }
}
4 changes: 2 additions & 2 deletions src/transformers/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ pub mod types;

use crate::guc;
use crate::types::Transformer;
use generic::get_generic_svc_url;
use http_handler::openai_embedding_request;
use openai::{OPENAI_EMBEDDING_MODEL, OPENAI_EMBEDDING_URL};
use pgrx::prelude::*;
Expand Down Expand Up @@ -41,8 +42,7 @@ pub fn transform(input: &str, transformer: Transformer, api_key: Option<String>)
}
}
Transformer::all_MiniLM_L12_v2 => {
let url: String = guc::get_guc(guc::VectorizeGuc::EmbeddingServiceUrl)
.expect("failed to get embedding service url from GUC");
let url = get_generic_svc_url().expect("failed to get embedding service url from GUC");
let embedding_request = EmbeddingPayload {
input: vec![input.to_string()],
model: transformer.to_string(),
Expand Down

0 comments on commit 2136cf2

Please sign in to comment.