Skip to content

Commit

Permalink
handle portkey virtkey in bgw (#138)
Browse files Browse the repository at this point in the history
* handle portkey virtkey in bgw

* bump to 0.18.1

* unused code

* move vectorscake statement to db init

* add log on test

* setup

* run on ubuntu 24

* job
  • Loading branch information
ChuckHend authored Aug 26, 2024
1 parent 8c527c2 commit a7db92a
Show file tree
Hide file tree
Showing 10 changed files with 70 additions and 51 deletions.
22 changes: 12 additions & 10 deletions .github/workflows/extension_ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ on:
jobs:
dependencies:
name: Install dependencies
runs-on: ubuntu-22.04
runs-on: ubuntu-24.04
steps:
- uses: actions/checkout@v2

Expand Down Expand Up @@ -85,7 +85,7 @@ jobs:
test:
name: Run tests
needs: dependencies
runs-on: ubuntu-22.04
runs-on: ubuntu-24.04
services:
# Label used to access the service container
vector-serve:
Expand All @@ -100,10 +100,15 @@ jobs:
toolchain: stable
- uses: Swatinem/rust-cache@v2
with:
prefix-key: "pg-vectorize-extension-test"
workspaces: pg-vectorize
prefix-key: "extension-test"
workspaces: |
vectorize
# Additional directories to cache
cache-directories: /home/runner/.pgrx
cache-directories: |
/home/runner/.pgrx
- name: Install sys dependencies
run: |
sudo apt-get update && sudo apt-get install -y postgresql-server-dev-16 libopenblas-dev libreadline-dev
- uses: ./.github/actions/pgx-init
with:
working-directory: ./extension
Expand All @@ -126,10 +131,7 @@ jobs:
${{ runner.os }}-bins-
- name: setup-tests
run: |
make trunk-dependencies
make setup.urls
make setup.shared_preload_libraries
rm -rf ./target/pgrx-test-data-* || true
make setup
- name: unit-test
run: |
make test-unit
Expand All @@ -146,7 +148,7 @@ jobs:
publish:
if: github.event_name == 'release'
name: trunk publish
runs-on: ubuntu-22.04
runs-on: ubuntu-24.04
strategy:
matrix:
pg-version: [14, 15, 16]
Expand Down
1 change: 1 addition & 0 deletions core/src/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@ pub struct JobParams {
pub api_key: Option<String>,
#[serde(default = "default_schedule")]
pub schedule: String,
pub args: Option<serde_json::Value>,
}

fn default_schedule() -> String {
Expand Down
8 changes: 7 additions & 1 deletion core/src/worker/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -94,11 +94,17 @@ async fn execute_job(
let job_meta: VectorizeMeta = msg.message.job_meta;
let job_params: JobParams = serde_json::from_value(job_meta.params.clone())?;

let virtual_key = if let Some(args) = job_params.args.clone() {
args.get("virtual_key").map(|v| v.to_string())
} else {
None
};

let provider = providers::get_provider(
&job_meta.transformer.source,
job_params.api_key.clone(),
None,
None,
virtual_key,
)?;

let embedding_request =
Expand Down
2 changes: 1 addition & 1 deletion extension/Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "vectorize"
version = "0.18.0"
version = "0.18.1"
edition = "2021"
publish = false

Expand Down
4 changes: 2 additions & 2 deletions extension/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -72,10 +72,10 @@ install-pgvector:
install-pgmq:
git clone https://github.com/tembo-io/pgmq.git && \
cd pgmq/pgmq-extension && \
PG_CONFIG=${PGRX_PG_CONFIG} make && \
PG_CONFIG=${PGRX_PG_CONFIG} make clean && \
PG_CONFIG=${PGRX_PG_CONFIG} make && \
PG_CONFIG=${PGRX_PG_CONFIG} make install && \
cd .. && rm -rf pgmq
cd ../.. && rm -rf pgmq

install-vectorscale:
@ARCH=$$(uname -m); \
Expand Down
2 changes: 1 addition & 1 deletion extension/Trunk.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ description = "The simplest way to orchestrate vector search on Postgres."
homepage = "https://github.com/tembo-io/pg_vectorize"
documentation = "https://github.com/tembo-io/pg_vectorize"
categories = ["orchestration", "machine_learning"]
version = "0.18.0"
version = "0.18.1"
loadable_libraries = [{ library_name = "vectorize", requires_restart = true }]

[build]
Expand Down
Empty file.
59 changes: 33 additions & 26 deletions extension/src/search.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,38 +39,19 @@ pub fn init_table(
init::init_pgmq()?;

let guc_configs = get_guc_configs(&transformer.source);
let provider = get_provider(
&transformer.source,
guc_configs.api_key.clone(),
guc_configs.service_url,
None,
)?;

//synchronous
let runtime = tokio::runtime::Builder::new_current_thread()
.enable_io()
.enable_time()
.build()
.unwrap_or_else(|e| error!("failed to initialize tokio runtime: {}", e));
let model_dim =
match runtime.block_on(async { provider.model_dim(&transformer.api_name()).await }) {
Ok(e) => e,
Err(e) => {
error!("error getting model dim: {}", e);
}
};

// validate API key where necessary
info!("guc_configs: {:?}", guc_configs);
// validate API key where necessary and collect any optional arguments
// certain embedding services require an API key, e.g. openAI
// key can be set in a GUC, so if its required but not provided in args, and not in GUC, error
match transformer.source {
let optional_args = match transformer.source {
ModelSource::OpenAI => {
openai::validate_api_key(
&guc_configs
.api_key
.clone()
.context("OpenAI key is required")?,
)?;
None
}
ModelSource::Tembo => {
error!("Tembo not implemented for search yet");
Expand All @@ -85,15 +66,40 @@ pub fn init_table(
let res = check_model_host(&url);
match res {
Ok(_) => {
info!("Model host active!")
info!("Model host active!");
None
}
Err(e) => {
error!("Error with model host: {:?}", e)
}
}
}
_ => (),
}
ModelSource::Portkey => Some(serde_json::json!({
"virtual_key": guc_configs.virtual_key.clone().expect("Portkey virtual key is required")
})),
_ => None,
};

let provider = get_provider(
&transformer.source,
guc_configs.api_key.clone(),
guc_configs.service_url.clone(),
guc_configs.virtual_key.clone(),
)?;

// synchronous
let runtime = tokio::runtime::Builder::new_current_thread()
.enable_io()
.enable_time()
.build()
.unwrap_or_else(|e| error!("failed to initialize tokio runtime: {}", e));
let model_dim =
match runtime.block_on(async { provider.model_dim(&transformer.api_name()).await }) {
Ok(e) => e,
Err(e) => {
error!("error getting model dim: {}", e);
}
};

let valid_params = types::JobParams {
schema: schema.to_string(),
Expand All @@ -105,6 +111,7 @@ pub fn init_table(
pkey_type,
api_key: guc_configs.api_key.clone(),
schedule: schedule.to_string(),
args: optional_args,
};
let params =
pgrx::JsonB(serde_json::to_value(valid_params.clone()).expect("error serializing params"));
Expand Down
14 changes: 8 additions & 6 deletions extension/tests/integration_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -788,10 +788,6 @@ async fn test_diskann_cosine() {
common::init_test_table(&test_table_name, &conn).await;
let job_name = format!("job_diskann_{}", test_num);

let _ = sqlx::query("CREATE EXTENSION IF NOT EXISTS vectorscale;")
.execute(&conn)
.await;

common::init_embedding_svc_url(&conn).await;
// initialize a job
let result = sqlx::query(&format!(
Expand All @@ -810,9 +806,15 @@ async fn test_diskann_cosine() {
assert!(result.is_ok());

let search_results: Vec<common::SearchJSON> =
util::common::search_with_retry(&conn, "mobile devices", &job_name, 10, 2, 3, None)
match util::common::search_with_retry(&conn, "mobile devices", &job_name, 10, 2, 3, None)
.await
.unwrap();
{
Ok(results) => results,
Err(e) => {
eprintln!("Error: {:?}", e);
panic!("failed to exec search on diskann");
}
};
assert_eq!(search_results.len(), 3);
}

Expand Down
9 changes: 5 additions & 4 deletions extension/tests/util.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,11 @@ pub mod common {
.await
.expect("failed to create extension");

// Optional dependencies
let _ = sqlx::query("CREATE EXTENSION IF NOT EXISTS vectorscale CASCADE")
.execute(&conn)
.await
.expect("failed to create vectorscale extension");
conn
}

Expand All @@ -63,10 +68,6 @@ pub mod common {
28815
} else if cfg!(feature = "pg14") {
28814
} else if cfg!(feature = "pg13") {
28813
} else if cfg!(feature = "pg12") {
28812
} else {
5432
}
Expand Down

0 comments on commit a7db92a

Please sign in to comment.