From a7db92a040951a689f15771e2129d44c3296a1d4 Mon Sep 17 00:00:00 2001 From: Adam Hendel Date: Mon, 26 Aug 2024 16:10:01 -0500 Subject: [PATCH] handle portkey virtkey in bgw (#138) * handle portkey virtkey in bgw * bump to 0.18.1 * unused code * move vectorscake statement to db init * add log on test * setup * run on ubuntu 24 * job --- .github/workflows/extension_ci.yml | 22 ++++---- core/src/types.rs | 1 + core/src/worker/base.rs | 8 ++- extension/Cargo.toml | 2 +- extension/Makefile | 4 +- extension/Trunk.toml | 2 +- extension/sql/vectorize--0.18.0--0.18.1.sql | 0 extension/src/search.rs | 59 ++++++++++++--------- extension/tests/integration_tests.rs | 14 ++--- extension/tests/util.rs | 9 ++-- 10 files changed, 70 insertions(+), 51 deletions(-) create mode 100644 extension/sql/vectorize--0.18.0--0.18.1.sql diff --git a/.github/workflows/extension_ci.yml b/.github/workflows/extension_ci.yml index 053ba13..f20ba96 100644 --- a/.github/workflows/extension_ci.yml +++ b/.github/workflows/extension_ci.yml @@ -25,7 +25,7 @@ on: jobs: dependencies: name: Install dependencies - runs-on: ubuntu-22.04 + runs-on: ubuntu-24.04 steps: - uses: actions/checkout@v2 @@ -85,7 +85,7 @@ jobs: test: name: Run tests needs: dependencies - runs-on: ubuntu-22.04 + runs-on: ubuntu-24.04 services: # Label used to access the service container vector-serve: @@ -100,10 +100,15 @@ jobs: toolchain: stable - uses: Swatinem/rust-cache@v2 with: - prefix-key: "pg-vectorize-extension-test" - workspaces: pg-vectorize + prefix-key: "extension-test" + workspaces: | + vectorize # Additional directories to cache - cache-directories: /home/runner/.pgrx + cache-directories: | + /home/runner/.pgrx + - name: Install sys dependencies + run: | + sudo apt-get update && sudo apt-get install -y postgresql-server-dev-16 libopenblas-dev libreadline-dev - uses: ./.github/actions/pgx-init with: working-directory: ./extension @@ -126,10 +131,7 @@ jobs: ${{ runner.os }}-bins- - name: setup-tests run: | - make trunk-dependencies - make setup.urls - make setup.shared_preload_libraries - rm -rf ./target/pgrx-test-data-* || true + make setup - name: unit-test run: | make test-unit @@ -146,7 +148,7 @@ jobs: publish: if: github.event_name == 'release' name: trunk publish - runs-on: ubuntu-22.04 + runs-on: ubuntu-24.04 strategy: matrix: pg-version: [14, 15, 16] diff --git a/core/src/types.rs b/core/src/types.rs index 67419d1..3828a1a 100644 --- a/core/src/types.rs +++ b/core/src/types.rs @@ -114,6 +114,7 @@ pub struct JobParams { pub api_key: Option, #[serde(default = "default_schedule")] pub schedule: String, + pub args: Option, } fn default_schedule() -> String { diff --git a/core/src/worker/base.rs b/core/src/worker/base.rs index 2908e1a..928fc0a 100644 --- a/core/src/worker/base.rs +++ b/core/src/worker/base.rs @@ -94,11 +94,17 @@ async fn execute_job( let job_meta: VectorizeMeta = msg.message.job_meta; let job_params: JobParams = serde_json::from_value(job_meta.params.clone())?; + let virtual_key = if let Some(args) = job_params.args.clone() { + args.get("virtual_key").map(|v| v.to_string()) + } else { + None + }; + let provider = providers::get_provider( &job_meta.transformer.source, job_params.api_key.clone(), None, - None, + virtual_key, )?; let embedding_request = diff --git a/extension/Cargo.toml b/extension/Cargo.toml index 2b3be00..2351402 100644 --- a/extension/Cargo.toml +++ b/extension/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "vectorize" -version = "0.18.0" +version = "0.18.1" edition = "2021" publish = false diff --git a/extension/Makefile b/extension/Makefile index fb3888c..77aac9b 100644 --- a/extension/Makefile +++ b/extension/Makefile @@ -72,10 +72,10 @@ install-pgvector: install-pgmq: git clone https://github.com/tembo-io/pgmq.git && \ cd pgmq/pgmq-extension && \ - PG_CONFIG=${PGRX_PG_CONFIG} make && \ PG_CONFIG=${PGRX_PG_CONFIG} make clean && \ + PG_CONFIG=${PGRX_PG_CONFIG} make && \ PG_CONFIG=${PGRX_PG_CONFIG} make install && \ - cd .. && rm -rf pgmq + cd ../.. && rm -rf pgmq install-vectorscale: @ARCH=$$(uname -m); \ diff --git a/extension/Trunk.toml b/extension/Trunk.toml index cd98e29..92e37eb 100644 --- a/extension/Trunk.toml +++ b/extension/Trunk.toml @@ -6,7 +6,7 @@ description = "The simplest way to orchestrate vector search on Postgres." homepage = "https://github.com/tembo-io/pg_vectorize" documentation = "https://github.com/tembo-io/pg_vectorize" categories = ["orchestration", "machine_learning"] -version = "0.18.0" +version = "0.18.1" loadable_libraries = [{ library_name = "vectorize", requires_restart = true }] [build] diff --git a/extension/sql/vectorize--0.18.0--0.18.1.sql b/extension/sql/vectorize--0.18.0--0.18.1.sql new file mode 100644 index 0000000..e69de29 diff --git a/extension/src/search.rs b/extension/src/search.rs index c901f13..db7e7b4 100644 --- a/extension/src/search.rs +++ b/extension/src/search.rs @@ -39,31 +39,11 @@ pub fn init_table( init::init_pgmq()?; let guc_configs = get_guc_configs(&transformer.source); - let provider = get_provider( - &transformer.source, - guc_configs.api_key.clone(), - guc_configs.service_url, - None, - )?; - - //synchronous - let runtime = tokio::runtime::Builder::new_current_thread() - .enable_io() - .enable_time() - .build() - .unwrap_or_else(|e| error!("failed to initialize tokio runtime: {}", e)); - let model_dim = - match runtime.block_on(async { provider.model_dim(&transformer.api_name()).await }) { - Ok(e) => e, - Err(e) => { - error!("error getting model dim: {}", e); - } - }; - - // validate API key where necessary + info!("guc_configs: {:?}", guc_configs); + // validate API key where necessary and collect any optional arguments // certain embedding services require an API key, e.g. openAI // key can be set in a GUC, so if its required but not provided in args, and not in GUC, error - match transformer.source { + let optional_args = match transformer.source { ModelSource::OpenAI => { openai::validate_api_key( &guc_configs @@ -71,6 +51,7 @@ pub fn init_table( .clone() .context("OpenAI key is required")?, )?; + None } ModelSource::Tembo => { error!("Tembo not implemented for search yet"); @@ -85,15 +66,40 @@ pub fn init_table( let res = check_model_host(&url); match res { Ok(_) => { - info!("Model host active!") + info!("Model host active!"); + None } Err(e) => { error!("Error with model host: {:?}", e) } } } - _ => (), - } + ModelSource::Portkey => Some(serde_json::json!({ + "virtual_key": guc_configs.virtual_key.clone().expect("Portkey virtual key is required") + })), + _ => None, + }; + + let provider = get_provider( + &transformer.source, + guc_configs.api_key.clone(), + guc_configs.service_url.clone(), + guc_configs.virtual_key.clone(), + )?; + + // synchronous + let runtime = tokio::runtime::Builder::new_current_thread() + .enable_io() + .enable_time() + .build() + .unwrap_or_else(|e| error!("failed to initialize tokio runtime: {}", e)); + let model_dim = + match runtime.block_on(async { provider.model_dim(&transformer.api_name()).await }) { + Ok(e) => e, + Err(e) => { + error!("error getting model dim: {}", e); + } + }; let valid_params = types::JobParams { schema: schema.to_string(), @@ -105,6 +111,7 @@ pub fn init_table( pkey_type, api_key: guc_configs.api_key.clone(), schedule: schedule.to_string(), + args: optional_args, }; let params = pgrx::JsonB(serde_json::to_value(valid_params.clone()).expect("error serializing params")); diff --git a/extension/tests/integration_tests.rs b/extension/tests/integration_tests.rs index 9cee413..9664693 100644 --- a/extension/tests/integration_tests.rs +++ b/extension/tests/integration_tests.rs @@ -788,10 +788,6 @@ async fn test_diskann_cosine() { common::init_test_table(&test_table_name, &conn).await; let job_name = format!("job_diskann_{}", test_num); - let _ = sqlx::query("CREATE EXTENSION IF NOT EXISTS vectorscale;") - .execute(&conn) - .await; - common::init_embedding_svc_url(&conn).await; // initialize a job let result = sqlx::query(&format!( @@ -810,9 +806,15 @@ async fn test_diskann_cosine() { assert!(result.is_ok()); let search_results: Vec = - util::common::search_with_retry(&conn, "mobile devices", &job_name, 10, 2, 3, None) + match util::common::search_with_retry(&conn, "mobile devices", &job_name, 10, 2, 3, None) .await - .unwrap(); + { + Ok(results) => results, + Err(e) => { + eprintln!("Error: {:?}", e); + panic!("failed to exec search on diskann"); + } + }; assert_eq!(search_results.len(), 3); } diff --git a/extension/tests/util.rs b/extension/tests/util.rs index 0a01e46..982c353 100644 --- a/extension/tests/util.rs +++ b/extension/tests/util.rs @@ -53,6 +53,11 @@ pub mod common { .await .expect("failed to create extension"); + // Optional dependencies + let _ = sqlx::query("CREATE EXTENSION IF NOT EXISTS vectorscale CASCADE") + .execute(&conn) + .await + .expect("failed to create vectorscale extension"); conn } @@ -63,10 +68,6 @@ pub mod common { 28815 } else if cfg!(feature = "pg14") { 28814 - } else if cfg!(feature = "pg13") { - 28813 - } else if cfg!(feature = "pg12") { - 28812 } else { 5432 }