diff --git a/.github/workflows/diesel.yml b/.github/workflows/diesel.yml index 3a4d3d96d..7b77f484b 100644 --- a/.github/workflows/diesel.yml +++ b/.github/workflows/diesel.yml @@ -44,7 +44,7 @@ jobs: steps: - uses: actions/checkout@v3 - uses: dtolnay/rust-toolchain@stable - - run: cargo build --manifest-path sea-query-diesel/Cargo.toml --workspace --features postgres,sqlite,mysql --features=with-chrono,with-json,with-rust_decimal,with-bigdecimal,with-uuid,with-time,with-ipnetwork,with-mac_address,postgres-array + - run: cargo build --manifest-path sea-query-diesel/Cargo.toml --workspace --features postgres,sqlite,mysql --features=with-chrono,with-json,with-rust_decimal,with-bigdecimal,with-uuid,with-time,with-ipnetwork,with-mac_address,postgres-array,postgres-vector - run: cargo build --manifest-path sea-query-diesel/Cargo.toml --workspace --features postgres,sqlite,mysql --features=with-chrono - run: cargo build --manifest-path sea-query-diesel/Cargo.toml --workspace --features postgres,sqlite,mysql --features=with-json - run: cargo build --manifest-path sea-query-diesel/Cargo.toml --workspace --features postgres,sqlite,mysql --features=with-rust_decimal @@ -56,6 +56,7 @@ jobs: - run: cargo build --manifest-path sea-query-diesel/Cargo.toml --workspace --features postgres,sqlite,mysql --features=with-ipnetwork - run: cargo build --manifest-path sea-query-diesel/Cargo.toml --workspace --features postgres,sqlite,mysql --features=with-mac_address - run: cargo build --manifest-path sea-query-diesel/Cargo.toml --workspace --features postgres,sqlite,mysql --features=postgres-array + - run: cargo build --manifest-path sea-query-diesel/Cargo.toml --workspace --features postgres,sqlite,mysql --features=postgres-vector sqlite: name: SQLite diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml index b2c763357..890e1f0cb 100644 --- a/.github/workflows/rust.yml +++ b/.github/workflows/rust.yml @@ -57,7 +57,7 @@ jobs: toolchain: stable components: clippy - run: cargo clippy --features=all-features --workspace -- -D warnings - - run: cargo clippy --manifest-path sea-query-binder/Cargo.toml --workspace --features runtime-async-std-rustls --features=with-chrono,with-json,with-rust_decimal,with-bigdecimal,with-uuid,with-time,with-ipnetwork,with-mac_address,postgres-array -- -D warnings + - run: cargo clippy --manifest-path sea-query-binder/Cargo.toml --workspace --features runtime-async-std-rustls --features=with-chrono,with-json,with-rust_decimal,with-bigdecimal,with-uuid,with-time,with-ipnetwork,with-mac_address,postgres-array,postgres-vector -- -D warnings - run: cargo clippy --manifest-path sea-query-rusqlite/Cargo.toml --all-features --workspace -- -D warnings - run: cargo clippy --manifest-path sea-query-postgres/Cargo.toml --all-features --workspace -- -D warnings @@ -78,6 +78,7 @@ jobs: - run: cargo build --workspace --features=with-ipnetwork - run: cargo build --workspace --features=with-mac_address - run: cargo build --workspace --features=postgres-array + - run: cargo build --workspace --features=postgres-vector - run: cargo build --workspace --features=thread-safe binder-build: @@ -91,7 +92,7 @@ jobs: steps: - uses: actions/checkout@v3 - uses: dtolnay/rust-toolchain@stable - - run: cargo build --manifest-path sea-query-binder/Cargo.toml --workspace --features sqlx-postgres,sqlx-sqlite,sqlx-any,sqlx-mysql --features=runtime-${{ matrix.runtime }}-${{ matrix.tls }} --features=with-chrono,with-json,with-rust_decimal,with-bigdecimal,with-uuid,with-time,with-ipnetwork,with-mac_address,postgres-array + - run: cargo build --manifest-path sea-query-binder/Cargo.toml --workspace --features sqlx-postgres,sqlx-sqlite,sqlx-any,sqlx-mysql --features=runtime-${{ matrix.runtime }}-${{ matrix.tls }} --features=with-chrono,with-json,with-rust_decimal,with-bigdecimal,with-uuid,with-time,with-ipnetwork,with-mac_address,postgres-array,postgres-vector - run: cargo build --manifest-path sea-query-binder/Cargo.toml --workspace --features sqlx-postgres,sqlx-sqlite,sqlx-any,sqlx-mysql --features=runtime-${{ matrix.runtime }}-${{ matrix.tls }} --features=with-chrono - run: cargo build --manifest-path sea-query-binder/Cargo.toml --workspace --features sqlx-postgres,sqlx-sqlite,sqlx-any,sqlx-mysql --features=runtime-${{ matrix.runtime }}-${{ matrix.tls }} --features=with-json - run: cargo build --manifest-path sea-query-binder/Cargo.toml --workspace --features sqlx-postgres,sqlx-sqlite,sqlx-any,sqlx-mysql --features=runtime-${{ matrix.runtime }}-${{ matrix.tls }} --features=with-rust_decimal @@ -101,6 +102,7 @@ jobs: - run: cargo build --manifest-path sea-query-binder/Cargo.toml --workspace --features sqlx-postgres,sqlx-sqlite,sqlx-any,sqlx-mysql --features=runtime-${{ matrix.runtime }}-${{ matrix.tls }} --features=with-ipnetwork - run: cargo build --manifest-path sea-query-binder/Cargo.toml --workspace --features sqlx-postgres,sqlx-sqlite,sqlx-any,sqlx-mysql --features=runtime-${{ matrix.runtime }}-${{ matrix.tls }} --features=with-mac_address - run: cargo build --manifest-path sea-query-binder/Cargo.toml --workspace --features sqlx-postgres,sqlx-sqlite,sqlx-any,sqlx-mysql --features=runtime-${{ matrix.runtime }}-${{ matrix.tls }} --features=postgres-array + - run: cargo build --manifest-path sea-query-binder/Cargo.toml --workspace --features sqlx-postgres,sqlx-sqlite,sqlx-any,sqlx-mysql --features=runtime-${{ matrix.runtime }}-${{ matrix.tls }} --features=postgres-vector rusqlite-build: name: Build `sea-query-rusqlite` @@ -108,7 +110,7 @@ jobs: steps: - uses: actions/checkout@v3 - uses: dtolnay/rust-toolchain@stable - - run: cargo build --manifest-path sea-query-rusqlite/Cargo.toml --workspace --features=with-chrono,with-json,with-rust_decimal,with-bigdecimal,with-uuid,with-time,with-ipnetwork,with-mac_address,postgres-array + - run: cargo build --manifest-path sea-query-rusqlite/Cargo.toml --workspace --features=with-chrono,with-json,with-rust_decimal,with-bigdecimal,with-uuid,with-time,with-ipnetwork,with-mac_address,postgres-array,postgres-vector - run: cargo build --manifest-path sea-query-rusqlite/Cargo.toml --workspace --features=with-chrono - run: cargo build --manifest-path sea-query-rusqlite/Cargo.toml --workspace --features=with-json - run: cargo build --manifest-path sea-query-rusqlite/Cargo.toml --workspace --features=with-rust_decimal @@ -118,6 +120,7 @@ jobs: - run: cargo build --manifest-path sea-query-rusqlite/Cargo.toml --workspace --features=with-ipnetwork - run: cargo build --manifest-path sea-query-rusqlite/Cargo.toml --workspace --features=with-mac_address - run: cargo build --manifest-path sea-query-rusqlite/Cargo.toml --workspace --features=postgres-array + - run: cargo build --manifest-path sea-query-rusqlite/Cargo.toml --workspace --features=postgres-vector postgres-build: name: Build `sea-query-postgres` @@ -125,7 +128,7 @@ jobs: steps: - uses: actions/checkout@v3 - uses: dtolnay/rust-toolchain@stable - - run: cargo build --manifest-path sea-query-postgres/Cargo.toml --workspace --features=with-chrono,with-json,with-rust_decimal,with-bigdecimal,with-uuid,with-time,with-ipnetwork,with-mac_address,postgres-array + - run: cargo build --manifest-path sea-query-postgres/Cargo.toml --workspace --features=with-chrono,with-json,with-rust_decimal,with-bigdecimal,with-uuid,with-time,with-ipnetwork,with-mac_address,postgres-array,postgres-vector - run: cargo build --manifest-path sea-query-postgres/Cargo.toml --workspace --features=with-chrono - run: cargo build --manifest-path sea-query-postgres/Cargo.toml --workspace --features=with-json - run: cargo build --manifest-path sea-query-postgres/Cargo.toml --workspace --features=with-rust_decimal @@ -135,6 +138,7 @@ jobs: - run: cargo build --manifest-path sea-query-postgres/Cargo.toml --workspace --features=with-ipnetwork - run: cargo build --manifest-path sea-query-postgres/Cargo.toml --workspace --features=with-mac_address - run: cargo build --manifest-path sea-query-postgres/Cargo.toml --workspace --features=postgres-array + - run: cargo build --manifest-path sea-query-postgres/Cargo.toml --workspace --features=postgres-vector test: name: Unit Test diff --git a/Cargo.toml b/Cargo.toml index 12848516b..a92348135 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -35,6 +35,7 @@ serde_json = { version = "1", default-features = false, optional = true, feature educe = { version = "=0.5.11", default-features = false, optional = true, features = ["Hash", "PartialEq", "Eq"] } chrono = { version = "0.4.27", default-features = false, optional = true, features = ["clock"] } postgres-types = { version = "0", default-features = false, optional = true } +pgvector = { version = "<0.4", default-features = false, optional = true } rust_decimal = { version = "1", default-features = false, optional = true } bigdecimal = { version = "0.4", default-features = false, optional = true } uuid = { version = "1", default-features = false, optional = true } @@ -58,6 +59,7 @@ derive = ["sea-query-derive"] attr = ["sea-query-attr"] hashable-value = ["educe", "ordered-float"] postgres-array = [] +postgres-vector = ["pgvector"] postgres-interval = [] thread-safe = [] with-chrono = ["chrono"] @@ -80,6 +82,7 @@ all-features = [ "all-types", ] # everything except option-* all-types = [ + "postgres-vector", "postgres-array", "postgres-interval", "with-chrono", diff --git a/sea-query-binder/Cargo.toml b/sea-query-binder/Cargo.toml index 47b5ddcd0..1ed18e292 100644 --- a/sea-query-binder/Cargo.toml +++ b/sea-query-binder/Cargo.toml @@ -27,6 +27,7 @@ uuid = { version = "1", default-features = false, optional = true } time = { version = "0.3.36", default-features = false, optional = true, features = ["macros", "formatting"] } ipnetwork = { version = "0.20", default-features = false, optional = true } mac_address = { version = "1.1", default-features = false, optional = true } +pgvector = { version = "<0.4", default-features = false, optional = true } [features] sqlx-mysql = ["sqlx/mysql"] @@ -42,6 +43,7 @@ with-time = ["sqlx?/time", "sea-query/with-time", "time"] with-ipnetwork = ["sqlx?/ipnetwork", "sea-query/with-ipnetwork", "ipnetwork"] with-mac_address = ["sqlx?/mac_address", "sea-query/with-mac_address", "mac_address"] postgres-array = ["sea-query/postgres-array"] +postgres-vector = ["sea-query/postgres-vector", "pgvector/sqlx"] runtime-async-std = ["sqlx?/runtime-async-std"] runtime-async-std-native-tls = ["sqlx?/runtime-async-std-native-tls"] runtime-async-std-rustls = ["sqlx?/runtime-async-std-rustls", ] diff --git a/sea-query-binder/src/sqlx_any.rs b/sea-query-binder/src/sqlx_any.rs index 5a73fca22..263ec8b9d 100644 --- a/sea-query-binder/src/sqlx_any.rs +++ b/sea-query-binder/src/sqlx_any.rs @@ -120,6 +120,10 @@ impl<'q> sqlx::IntoArguments<'q, sqlx::any::Any> for SqlxValues { Value::Array(_, _) => { panic!("SQLx doesn't support array arguments for Any"); } + #[cfg(feature = "postgres-vector")] + Value::Vector(_) => { + panic!("SQLx doesn't support vector arguments for Any"); + } } } args diff --git a/sea-query-binder/src/sqlx_mysql.rs b/sea-query-binder/src/sqlx_mysql.rs index 08c34ba4b..3261ef293 100644 --- a/sea-query-binder/src/sqlx_mysql.rs +++ b/sea-query-binder/src/sqlx_mysql.rs @@ -110,6 +110,10 @@ impl sqlx::IntoArguments<'_, sqlx::mysql::MySql> for SqlxValues { Value::Array(_, _) => { panic!("Mysql doesn't support array arguments"); } + #[cfg(feature = "postgres-vector")] + Value::Vector(_) => { + panic!("Mysql doesn't support vector arguments"); + } #[cfg(feature = "with-ipnetwork")] Value::IpNetwork(_) => { panic!("Mysql doesn't support IpNetwork arguments"); diff --git a/sea-query-binder/src/sqlx_postgres.rs b/sea-query-binder/src/sqlx_postgres.rs index 12c4e70ba..314a7e44c 100644 --- a/sea-query-binder/src/sqlx_postgres.rs +++ b/sea-query-binder/src/sqlx_postgres.rs @@ -314,6 +314,10 @@ impl sqlx::IntoArguments<'_, sqlx::postgres::Postgres> for SqlxValues { let _ = args.add(value); } }, + #[cfg(feature = "postgres-vector")] + Value::Vector(v) => { + let _ = args.add(v.as_deref()); + } } } args diff --git a/sea-query-binder/src/sqlx_sqlite.rs b/sea-query-binder/src/sqlx_sqlite.rs index 14384d7b9..e31751a0b 100644 --- a/sea-query-binder/src/sqlx_sqlite.rs +++ b/sea-query-binder/src/sqlx_sqlite.rs @@ -120,6 +120,10 @@ impl<'q> sqlx::IntoArguments<'q, sqlx::sqlite::Sqlite> for SqlxValues { Value::Array(_, _) => { panic!("Sqlite doesn't support array arguments"); } + #[cfg(feature = "postgres-vector")] + Value::Vector(_) => { + panic!("Sqlite doesn't support vector arguments"); + } } } args diff --git a/sea-query-diesel/Cargo.toml b/sea-query-diesel/Cargo.toml index 2f9120812..0e1545515 100644 --- a/sea-query-diesel/Cargo.toml +++ b/sea-query-diesel/Cargo.toml @@ -29,6 +29,7 @@ uuid = { version = "1", default-features = false, optional = true } serde_json = { version = "1", default-features = false, optional = true } ipnetwork = { version = "0.20", default-features = false, optional = true } mac_address = { version = "1.1", default-features = false, optional = true } +pgvector = { version = "<0.4", default-features = false, optional = true } [features] default = [] @@ -64,3 +65,4 @@ with-ipnetwork = [ ] with-mac_address = ["sea-query/with-mac_address", "mac_address"] postgres-array = ["sea-query/postgres-array"] +postgres-vector = ["sea-query/postgres-vector", "pgvector/diesel"] diff --git a/sea-query-diesel/src/backend/mysql.rs b/sea-query-diesel/src/backend/mysql.rs index ee8ccd0bd..631eeaadb 100644 --- a/sea-query-diesel/src/backend/mysql.rs +++ b/sea-query-diesel/src/backend/mysql.rs @@ -73,6 +73,8 @@ impl TransformValue for Mysql { Value::MacAddress(_) => bail!("Mysql doesn't support MacAddress arguments"), #[cfg(feature = "postgres-array")] Value::Array(_, _) => bail!("Mysql doesn't support array arguments"), + #[cfg(feature = "postgres-vector")] + Value::Vector(_) => bail!("Mysql doesn't support vector arguments"), }; Ok(transformed) } diff --git a/sea-query-diesel/src/backend/postgres.rs b/sea-query-diesel/src/backend/postgres.rs index ebf634a9d..a0bbfd425 100644 --- a/sea-query-diesel/src/backend/postgres.rs +++ b/sea-query-diesel/src/backend/postgres.rs @@ -205,6 +205,8 @@ impl TransformValue for Pg { ) } }, + #[cfg(feature = "postgres-vector")] + Value::Vector(v) => build!(pgvector::sql_types::Vector, v.map(|v| *v)), }; Ok(transformed) } diff --git a/sea-query-diesel/src/backend/sqlite.rs b/sea-query-diesel/src/backend/sqlite.rs index 45d1a8462..e101d0228 100644 --- a/sea-query-diesel/src/backend/sqlite.rs +++ b/sea-query-diesel/src/backend/sqlite.rs @@ -98,6 +98,8 @@ impl TransformValue for Sqlite { Value::MacAddress(_) => bail!("Sqlite doesn't support MacAddress arguments"), #[cfg(feature = "postgres-array")] Value::Array(_, _) => bail!("Sqlite doesn't support array arguments"), + #[cfg(feature = "postgres-vector")] + Value::Vector(_) => bail!("Sqlite doesn't support vector arguments"), }; Ok(transformed) } diff --git a/sea-query-postgres/Cargo.toml b/sea-query-postgres/Cargo.toml index 78f115221..84cd8ab4e 100644 --- a/sea-query-postgres/Cargo.toml +++ b/sea-query-postgres/Cargo.toml @@ -19,6 +19,7 @@ rust-version = "1.60" [dependencies] sea-query = { version = "0.31.0", path = "..", default-features = false } postgres-types = { version = "0.2", default-features = false } +pgvector = { version = "<0.4", default-features = false, optional = true } bytes = { version = "1", default-features = false } rust_decimal = { version = "1", default-features = false, optional = true } bigdecimal = { version = "0.4", default-features = false, optional = true } @@ -35,5 +36,6 @@ with-bigdecimal = ["sea-query/with-bigdecimal", "bigdecimal"] with-uuid = ["postgres-types/with-uuid-1", "sea-query/with-uuid"] with-time = ["postgres-types/with-time-0_3", "sea-query/with-time"] postgres-array = ["postgres-types/array-impls", "sea-query/postgres-array"] +postgres-vector = ["sea-query/postgres-vector", "pgvector/postgres"] with-ipnetwork = ["postgres-types/with-cidr-0_2", "sea-query/with-ipnetwork", "ipnetwork", "cidr"] with-mac_address = ["postgres-types/with-eui48-1", "sea-query/with-mac_address", "mac_address", "eui48"] diff --git a/sea-query-postgres/src/lib.rs b/sea-query-postgres/src/lib.rs index a5cdf6757..130567457 100644 --- a/sea-query-postgres/src/lib.rs +++ b/sea-query-postgres/src/lib.rs @@ -116,6 +116,10 @@ impl ToSql for PostgresValue { .to_sql(ty, out), #[cfg(feature = "postgres-array")] Value::Array(_, None) => Ok(IsNull::Yes), + #[cfg(feature = "postgres-vector")] + Value::Vector(Some(v)) => v.to_sql(ty, out), + #[cfg(feature = "postgres-vector")] + Value::Vector(None) => Ok(IsNull::Yes), #[cfg(feature = "with-ipnetwork")] Value::IpNetwork(v) => { use cidr::IpCidr; diff --git a/sea-query-rusqlite/Cargo.toml b/sea-query-rusqlite/Cargo.toml index 1806292b7..4c85f7e87 100644 --- a/sea-query-rusqlite/Cargo.toml +++ b/sea-query-rusqlite/Cargo.toml @@ -30,3 +30,4 @@ with-time = ["rusqlite/time", "sea-query/with-time"] with-ipnetwork = ["sea-query/with-ipnetwork"] with-mac_address = ["sea-query/with-mac_address"] postgres-array = ["sea-query/postgres-array"] +postgres-vector = ["sea-query/postgres-vector"] diff --git a/sea-query-rusqlite/src/lib.rs b/sea-query-rusqlite/src/lib.rs index 286ac7868..03d9918ad 100644 --- a/sea-query-rusqlite/src/lib.rs +++ b/sea-query-rusqlite/src/lib.rs @@ -130,6 +130,10 @@ impl ToSql for RusqliteValue { Value::Array(_, _) => { panic!("Rusqlite doesn't support Array arguments"); } + #[cfg(feature = "postgres-vector")] + Value::Vector(_) => { + panic!("Rusqlite doesn't support Vector arguments"); + } } } } diff --git a/src/backend/mysql/table.rs b/src/backend/mysql/table.rs index ec3594122..0da5ef304 100644 --- a/src/backend/mysql/table.rs +++ b/src/backend/mysql/table.rs @@ -90,6 +90,7 @@ impl TableBuilder for MysqlQueryBuilder { .join("', '") ), ColumnType::Array(_) => unimplemented!("Array is not available in MySQL."), + ColumnType::Vector(_) => unimplemented!("Vector is not available in MySQL."), ColumnType::Cidr => unimplemented!("Cidr is not available in MySQL."), ColumnType::Inet => unimplemented!("Inet is not available in MySQL."), ColumnType::MacAddr => unimplemented!("MacAddr is not available in MySQL."), diff --git a/src/backend/postgres/query.rs b/src/backend/postgres/query.rs index f7cce9d0d..fa48aa68d 100644 --- a/src/backend/postgres/query.rs +++ b/src/backend/postgres/query.rs @@ -90,6 +90,12 @@ impl QueryBuilder for PostgresQueryBuilder { PgBinOper::CastJsonField => "->>", PgBinOper::Regex => "~", PgBinOper::RegexCaseInsensitive => "~*", + #[cfg(feature = "postgres-vector")] + PgBinOper::EuclideanDistance => "<->", + #[cfg(feature = "postgres-vector")] + PgBinOper::NegativeInnerProduct => "<#>", + #[cfg(feature = "postgres-vector")] + PgBinOper::CosineDistance => "<=>", } ) .unwrap(), diff --git a/src/backend/postgres/table.rs b/src/backend/postgres/table.rs index c45dbc7eb..a1a3e2f9c 100644 --- a/src/backend/postgres/table.rs +++ b/src/backend/postgres/table.rs @@ -71,6 +71,10 @@ impl TableBuilder for PostgresQueryBuilder { self.prepare_column_type(elem_type, &mut sql); format!("{sql}[]") } + ColumnType::Vector(size) => match size { + Some(size) => format!("vector({size})"), + None => "vector".into(), + }, ColumnType::Custom(iden) => iden.to_string(), ColumnType::Enum { name, .. } => name.to_string(), ColumnType::Cidr => "cidr".into(), diff --git a/src/backend/query_builder.rs b/src/backend/query_builder.rs index e3823eaec..da0eea675 100644 --- a/src/backend/query_builder.rs +++ b/src/backend/query_builder.rs @@ -1045,6 +1045,8 @@ pub trait QueryBuilder: Value::MacAddress(None) => write!(s, "NULL").unwrap(), #[cfg(feature = "postgres-array")] Value::Array(_, None) => write!(s, "NULL").unwrap(), + #[cfg(feature = "postgres-vector")] + Value::Vector(None) => write!(s, "NULL").unwrap(), Value::Bool(Some(b)) => write!(s, "{}", if *b { "TRUE" } else { "FALSE" }).unwrap(), Value::TinyInt(Some(v)) => write!(s, "{v}").unwrap(), Value::SmallInt(Some(v)) => write!(s, "{v}").unwrap(), @@ -1118,6 +1120,17 @@ pub trait QueryBuilder: .join(",") ) .unwrap(), + #[cfg(feature = "postgres-vector")] + Value::Vector(Some(v)) => { + write!(s, "'[").unwrap(); + for (i, &element) in v.as_slice().iter().enumerate() { + if i != 0 { + write!(s, ",").unwrap(); + } + write!(s, "{element}").unwrap(); + } + write!(s, "]'").unwrap(); + } #[cfg(feature = "with-ipnetwork")] Value::IpNetwork(Some(v)) => write!(s, "'{v}'").unwrap(), #[cfg(feature = "with-mac_address")] diff --git a/src/backend/sqlite/table.rs b/src/backend/sqlite/table.rs index 834b07278..429a4572e 100644 --- a/src/backend/sqlite/table.rs +++ b/src/backend/sqlite/table.rs @@ -184,6 +184,7 @@ impl SqliteQueryBuilder { ColumnType::Custom(iden) => iden.to_string(), ColumnType::Enum { .. } => "enum_text".into(), ColumnType::Array(_) => unimplemented!("Array is not available in Sqlite."), + ColumnType::Vector(_) => unimplemented!("Vector is not available in Sqlite."), ColumnType::Cidr => unimplemented!("Cidr is not available in Sqlite."), ColumnType::Inet => unimplemented!("Inet is not available in Sqlite."), ColumnType::MacAddr => unimplemented!("MacAddr is not available in Sqlite."), diff --git a/src/extension/postgres/mod.rs b/src/extension/postgres/mod.rs index 6dd48041f..930b23e7d 100644 --- a/src/extension/postgres/mod.rs +++ b/src/extension/postgres/mod.rs @@ -37,6 +37,12 @@ pub enum PgBinOper { Regex, /// `~*`. Regex operator with case insensitive matching. RegexCaseInsensitive, + #[cfg(feature = "postgres-vector")] + EuclideanDistance, + #[cfg(feature = "postgres-vector")] + NegativeInnerProduct, + #[cfg(feature = "postgres-vector")] + CosineDistance, } impl From for BinOper { diff --git a/src/table/column.rs b/src/table/column.rs index eac4ee725..4f6d2ba9b 100644 --- a/src/table/column.rs +++ b/src/table/column.rs @@ -50,6 +50,7 @@ pub trait IntoColumnDef { /// | Uuid | binary(16) | uuid | uuid_text | /// | Enum | ENUM(...) | ENUM_NAME | enum_text | /// | Array | N/A | DATA_TYPE[] | N/A | +/// | Vector | N/A | vector | N/A | /// | Cidr | N/A | cidr | N/A | /// | Inet | N/A | inet | N/A | /// | MacAddr | N/A | macaddr | N/A | @@ -94,6 +95,7 @@ pub enum ColumnType { variants: Vec, }, Array(RcOrArc), + Vector(Option), Cidr, Inet, MacAddr, @@ -454,6 +456,12 @@ impl ColumnDef { self } + #[cfg(feature = "postgres-vector")] + pub fn vector(&mut self, size: Option) -> &mut Self { + self.types = Some(ColumnType::Vector(size)); + self + } + /// Set column type as timestamp pub fn timestamp(&mut self) -> &mut Self { self.types = Some(ColumnType::Timestamp); diff --git a/src/value.rs b/src/value.rs index 3c9dbbf28..7ca60fff6 100644 --- a/src/value.rs +++ b/src/value.rs @@ -234,6 +234,19 @@ pub enum Value { #[cfg_attr(docsrs, doc(cfg(feature = "postgres-array")))] Array(ArrayType, Option>>), + #[cfg(feature = "postgres-vector")] + #[cfg_attr(docsrs, doc(cfg(feature = "postgres-vector")))] + Vector( + #[cfg_attr( + feature = "hashable-value", + educe( + Hash(method(hashable_value::hash_vector)), + PartialEq(method(hashable_value::cmp_vector)) + ) + )] + Option>, + ), + #[cfg(feature = "with-ipnetwork")] #[cfg_attr(docsrs, doc(cfg(feature = "with-ipnetwork")))] IpNetwork(Option>), @@ -893,6 +906,45 @@ pub mod with_array { } } +#[cfg(feature = "postgres-vector")] +#[cfg_attr(docsrs, doc(cfg(feature = "postgres-vector")))] +pub mod with_vector { + use super::*; + + impl From for Value { + fn from(x: pgvector::Vector) -> Value { + Value::Vector(Some(Box::new(x))) + } + } + + impl Nullable for pgvector::Vector { + fn null() -> Value { + Value::Vector(None) + } + } + + impl ValueType for pgvector::Vector { + fn try_from(v: Value) -> Result { + match v { + Value::Vector(Some(x)) => Ok(*x), + _ => Err(ValueTypeErr), + } + } + + fn type_name() -> String { + stringify!(Vector).to_owned() + } + + fn array_type() -> ArrayType { + unimplemented!("Vector does not have array type") + } + + fn column_type() -> ColumnType { + ColumnType::Vector(None) + } + } +} + #[allow(unused_macros)] macro_rules! box_to_opt_ref { ( $v: expr ) => { @@ -1392,6 +1444,8 @@ pub fn sea_value_to_json_value(value: &Value) -> Json { Value::Uuid(None) => Json::Null, #[cfg(feature = "postgres-array")] Value::Array(_, None) => Json::Null, + #[cfg(feature = "postgres-vector")] + Value::Vector(None) => Json::Null, #[cfg(feature = "with-ipnetwork")] Value::IpNetwork(None) => Json::Null, #[cfg(feature = "with-mac_address")] @@ -1447,6 +1501,8 @@ pub fn sea_value_to_json_value(value: &Value) -> Json { Value::Array(_, Some(v)) => { Json::Array(v.as_ref().iter().map(sea_value_to_json_value).collect()) } + #[cfg(feature = "postgres-vector")] + Value::Vector(Some(v)) => Json::Array(v.as_slice().iter().map(|&v| v.into()).collect()), #[cfg(feature = "with-ipnetwork")] Value::IpNetwork(Some(_)) => CommonSqlQueryBuilder.value_to_string(value).into(), #[cfg(feature = "with-mac_address")] @@ -1969,6 +2025,41 @@ mod hashable_value { } } + #[cfg(feature = "postgres-vector")] + pub fn hash_vector(v: &Option>, state: &mut H) { + match v { + Some(v) => { + for &value in v.as_slice().iter() { + hash_f32(&Some(value), state); + } + } + None => "null".hash(state), + } + } + + #[cfg(feature = "postgres-vector")] + pub fn cmp_vector( + l: &Option>, + r: &Option>, + ) -> bool { + match (l, r) { + (Some(l), Some(r)) => { + let (l, r) = (l.as_slice(), r.as_slice()); + if l.len() != r.len() { + return false; + } + for (l, r) in l.iter().zip(r.iter()) { + if !cmp_f32(&Some(*l), &Some(*r)) { + return false; + } + } + true + } + (None, None) => true, + _ => false, + } + } + #[test] fn test_hash_value_0() { let hash_set: std::collections::HashSet = [ diff --git a/tests/postgres/query.rs b/tests/postgres/query.rs index a00d57513..2799a2991 100644 --- a/tests/postgres/query.rs +++ b/tests/postgres/query.rs @@ -2110,3 +2110,17 @@ fn test_issue_674_nested_logical_panic() { r#"SELECT "character" FROM "character" WHERE TRUE AND (TRUE AND TRUE AND TRUE)"# ); } + +#[test] +fn test_pgvector_select() { + assert_eq!( + Query::select() + .columns([Char::Character]) + .from(Char::Table) + .and_where( + Expr::col(Char::Character).eq(Expr::val(pgvector::Vector::from(vec![1.0, 2.0]))) + ) + .to_string(PostgresQueryBuilder), + r#"SELECT "character" FROM "character" WHERE "character" = '[1,2]'"# + ); +}