Skip to content

Commit

Permalink
Add pgvector type (#774)
Browse files Browse the repository at this point in the history
* feat: Add pgvector type

* feat: Update Workflow

* fix: Add feature flag to sea-query-rustqlite

* fix: Add hashing

* fix: Dependencies

* fix: Rusqlite match value missing

* fix: Wrong feature flag

* feat: PgVector diesel support

* fix: Add feature gate for column type

* feat: Add binary vector operations

* remove: Feature gate on colum type

* feat: Add vector size

* feat add vector to column def

* fix: Fix version of pgvector to be compatible with sqlx <0.7.5
  • Loading branch information
28Smiles authored Aug 9, 2024
1 parent d898d77 commit b83eaa9
Show file tree
Hide file tree
Showing 25 changed files with 194 additions and 5 deletions.
3 changes: 2 additions & 1 deletion .github/workflows/diesel.yml
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ jobs:
- uses: actions/checkout@v3
- uses: dtolnay/rust-toolchain@stable
- run: cargo update --manifest-path sea-query-diesel/Cargo.toml --workspace -p bigdecimal:0.4.5 --precise 0.3.1
- run: cargo build --manifest-path sea-query-diesel/Cargo.toml --workspace --features postgres,sqlite,mysql --features=with-chrono,with-json,with-rust_decimal,with-bigdecimal,with-uuid,with-time,with-ipnetwork,with-mac_address,postgres-array
- run: cargo build --manifest-path sea-query-diesel/Cargo.toml --workspace --features postgres,sqlite,mysql --features=with-chrono,with-json,with-rust_decimal,with-bigdecimal,with-uuid,with-time,with-ipnetwork,with-mac_address,postgres-array,postgres-vector
- run: cargo build --manifest-path sea-query-diesel/Cargo.toml --workspace --features postgres,sqlite,mysql --features=with-chrono
- run: cargo build --manifest-path sea-query-diesel/Cargo.toml --workspace --features postgres,sqlite,mysql --features=with-json
- run: cargo build --manifest-path sea-query-diesel/Cargo.toml --workspace --features postgres,sqlite,mysql --features=with-rust_decimal
Expand All @@ -57,6 +57,7 @@ jobs:
- run: cargo build --manifest-path sea-query-diesel/Cargo.toml --workspace --features postgres,sqlite,mysql --features=with-ipnetwork
- run: cargo build --manifest-path sea-query-diesel/Cargo.toml --workspace --features postgres,sqlite,mysql --features=with-mac_address
- run: cargo build --manifest-path sea-query-diesel/Cargo.toml --workspace --features postgres,sqlite,mysql --features=postgres-array
- run: cargo build --manifest-path sea-query-diesel/Cargo.toml --workspace --features postgres,sqlite,mysql --features=postgres-vector

sqlite:
name: SQLite
Expand Down
12 changes: 8 additions & 4 deletions .github/workflows/rust.yml
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ jobs:
toolchain: stable
components: clippy
- run: cargo clippy --features=all-features --workspace -- -D warnings
- run: cargo clippy --manifest-path sea-query-binder/Cargo.toml --workspace --features runtime-async-std-rustls --features=with-chrono,with-json,with-rust_decimal,with-bigdecimal,with-uuid,with-time,with-ipnetwork,with-mac_address,postgres-array -- -D warnings
- run: cargo clippy --manifest-path sea-query-binder/Cargo.toml --workspace --features runtime-async-std-rustls --features=with-chrono,with-json,with-rust_decimal,with-bigdecimal,with-uuid,with-time,with-ipnetwork,with-mac_address,postgres-array,postgres-vector -- -D warnings
- run: cargo clippy --manifest-path sea-query-rusqlite/Cargo.toml --all-features --workspace -- -D warnings
- run: cargo clippy --manifest-path sea-query-postgres/Cargo.toml --all-features --workspace -- -D warnings

Expand All @@ -78,6 +78,7 @@ jobs:
- run: cargo build --workspace --features=with-ipnetwork
- run: cargo build --workspace --features=with-mac_address
- run: cargo build --workspace --features=postgres-array
- run: cargo build --workspace --features=postgres-vector
- run: cargo build --workspace --features=thread-safe

binder-build:
Expand All @@ -91,7 +92,7 @@ jobs:
steps:
- uses: actions/checkout@v3
- uses: dtolnay/rust-toolchain@stable
- run: cargo build --manifest-path sea-query-binder/Cargo.toml --workspace --features sqlx-postgres,sqlx-sqlite,sqlx-any,sqlx-mysql --features=runtime-${{ matrix.runtime }}-${{ matrix.tls }} --features=with-chrono,with-json,with-rust_decimal,with-bigdecimal,with-uuid,with-time,with-ipnetwork,with-mac_address,postgres-array
- run: cargo build --manifest-path sea-query-binder/Cargo.toml --workspace --features sqlx-postgres,sqlx-sqlite,sqlx-any,sqlx-mysql --features=runtime-${{ matrix.runtime }}-${{ matrix.tls }} --features=with-chrono,with-json,with-rust_decimal,with-bigdecimal,with-uuid,with-time,with-ipnetwork,with-mac_address,postgres-array,postgres-vector
- run: cargo build --manifest-path sea-query-binder/Cargo.toml --workspace --features sqlx-postgres,sqlx-sqlite,sqlx-any,sqlx-mysql --features=runtime-${{ matrix.runtime }}-${{ matrix.tls }} --features=with-chrono
- run: cargo build --manifest-path sea-query-binder/Cargo.toml --workspace --features sqlx-postgres,sqlx-sqlite,sqlx-any,sqlx-mysql --features=runtime-${{ matrix.runtime }}-${{ matrix.tls }} --features=with-json
- run: cargo build --manifest-path sea-query-binder/Cargo.toml --workspace --features sqlx-postgres,sqlx-sqlite,sqlx-any,sqlx-mysql --features=runtime-${{ matrix.runtime }}-${{ matrix.tls }} --features=with-rust_decimal
Expand All @@ -101,14 +102,15 @@ jobs:
- run: cargo build --manifest-path sea-query-binder/Cargo.toml --workspace --features sqlx-postgres,sqlx-sqlite,sqlx-any,sqlx-mysql --features=runtime-${{ matrix.runtime }}-${{ matrix.tls }} --features=with-ipnetwork
- run: cargo build --manifest-path sea-query-binder/Cargo.toml --workspace --features sqlx-postgres,sqlx-sqlite,sqlx-any,sqlx-mysql --features=runtime-${{ matrix.runtime }}-${{ matrix.tls }} --features=with-mac_address
- run: cargo build --manifest-path sea-query-binder/Cargo.toml --workspace --features sqlx-postgres,sqlx-sqlite,sqlx-any,sqlx-mysql --features=runtime-${{ matrix.runtime }}-${{ matrix.tls }} --features=postgres-array
- run: cargo build --manifest-path sea-query-binder/Cargo.toml --workspace --features sqlx-postgres,sqlx-sqlite,sqlx-any,sqlx-mysql --features=runtime-${{ matrix.runtime }}-${{ matrix.tls }} --features=postgres-vector

rusqlite-build:
name: Build `sea-query-rusqlite`
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- uses: dtolnay/rust-toolchain@stable
- run: cargo build --manifest-path sea-query-rusqlite/Cargo.toml --workspace --features=with-chrono,with-json,with-rust_decimal,with-bigdecimal,with-uuid,with-time,with-ipnetwork,with-mac_address,postgres-array
- run: cargo build --manifest-path sea-query-rusqlite/Cargo.toml --workspace --features=with-chrono,with-json,with-rust_decimal,with-bigdecimal,with-uuid,with-time,with-ipnetwork,with-mac_address,postgres-array,postgres-vector
- run: cargo build --manifest-path sea-query-rusqlite/Cargo.toml --workspace --features=with-chrono
- run: cargo build --manifest-path sea-query-rusqlite/Cargo.toml --workspace --features=with-json
- run: cargo build --manifest-path sea-query-rusqlite/Cargo.toml --workspace --features=with-rust_decimal
Expand All @@ -118,14 +120,15 @@ jobs:
- run: cargo build --manifest-path sea-query-rusqlite/Cargo.toml --workspace --features=with-ipnetwork
- run: cargo build --manifest-path sea-query-rusqlite/Cargo.toml --workspace --features=with-mac_address
- run: cargo build --manifest-path sea-query-rusqlite/Cargo.toml --workspace --features=postgres-array
- run: cargo build --manifest-path sea-query-rusqlite/Cargo.toml --workspace --features=postgres-vector

postgres-build:
name: Build `sea-query-postgres`
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- uses: dtolnay/rust-toolchain@stable
- run: cargo build --manifest-path sea-query-postgres/Cargo.toml --workspace --features=with-chrono,with-json,with-rust_decimal,with-bigdecimal,with-uuid,with-time,with-ipnetwork,with-mac_address,postgres-array
- run: cargo build --manifest-path sea-query-postgres/Cargo.toml --workspace --features=with-chrono,with-json,with-rust_decimal,with-bigdecimal,with-uuid,with-time,with-ipnetwork,with-mac_address,postgres-array,postgres-vector
- run: cargo build --manifest-path sea-query-postgres/Cargo.toml --workspace --features=with-chrono
- run: cargo build --manifest-path sea-query-postgres/Cargo.toml --workspace --features=with-json
- run: cargo build --manifest-path sea-query-postgres/Cargo.toml --workspace --features=with-rust_decimal
Expand All @@ -135,6 +138,7 @@ jobs:
- run: cargo build --manifest-path sea-query-postgres/Cargo.toml --workspace --features=with-ipnetwork
- run: cargo build --manifest-path sea-query-postgres/Cargo.toml --workspace --features=with-mac_address
- run: cargo build --manifest-path sea-query-postgres/Cargo.toml --workspace --features=postgres-array
- run: cargo build --manifest-path sea-query-postgres/Cargo.toml --workspace --features=postgres-vector

test:
name: Unit Test
Expand Down
3 changes: 3 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ serde_json = { version = "1", default-features = false, optional = true, feature
educe = { version = "=0.5.11", default-features = false, optional = true, features = ["Hash", "PartialEq", "Eq"] }
chrono = { version = "0.4.27", default-features = false, optional = true, features = ["clock"] }
postgres-types = { version = "0", default-features = false, optional = true }
pgvector = { version = "<0.4", default-features = false, optional = true }
rust_decimal = { version = "1", default-features = false, optional = true }
bigdecimal = { version = "0.3", default-features = false, optional = true }
uuid = { version = "1", default-features = false, optional = true }
Expand All @@ -58,6 +59,7 @@ derive = ["sea-query-derive"]
attr = ["sea-query-attr"]
hashable-value = ["educe", "ordered-float"]
postgres-array = []
postgres-vector = ["pgvector"]
postgres-interval = []
thread-safe = []
with-chrono = ["chrono"]
Expand All @@ -80,6 +82,7 @@ all-features = [
"all-types",
] # everything except option-*
all-types = [
"postgres-vector",
"postgres-array",
"postgres-interval",
"with-chrono",
Expand Down
2 changes: 2 additions & 0 deletions sea-query-binder/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ uuid = { version = "1", default-features = false, optional = true }
time = { version = "0.3.36", default-features = false, optional = true, features = ["macros", "formatting"] }
ipnetwork = { version = "0.20", default-features = false, optional = true }
mac_address = { version = "1.1", default-features = false, optional = true }
pgvector = { version = "<0.4", default-features = false, optional = true }

[features]
sqlx-mysql = ["sqlx/mysql"]
Expand All @@ -42,6 +43,7 @@ with-time = ["sqlx?/time", "sea-query/with-time", "time"]
with-ipnetwork = ["sqlx?/ipnetwork", "sea-query/with-ipnetwork", "ipnetwork"]
with-mac_address = ["sqlx?/mac_address", "sea-query/with-mac_address", "mac_address"]
postgres-array = ["sea-query/postgres-array"]
postgres-vector = ["sea-query/postgres-vector", "pgvector/sqlx"]
runtime-async-std = ["sqlx?/runtime-async-std"]
runtime-async-std-native-tls = ["sqlx?/runtime-async-std-native-tls"]
runtime-async-std-rustls = ["sqlx?/runtime-async-std-rustls", ]
Expand Down
4 changes: 4 additions & 0 deletions sea-query-binder/src/sqlx_any.rs
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,10 @@ impl<'q> sqlx::IntoArguments<'q, sqlx::any::Any> for SqlxValues {
Value::Array(_, _) => {
panic!("SQLx doesn't support array arguments for Any");
}
#[cfg(feature = "postgres-vector")]
Value::Vector(_) => {
panic!("SQLx doesn't support vector arguments for Any");
}
}
}
args
Expand Down
4 changes: 4 additions & 0 deletions sea-query-binder/src/sqlx_mysql.rs
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,10 @@ impl sqlx::IntoArguments<'_, sqlx::mysql::MySql> for SqlxValues {
Value::Array(_, _) => {
panic!("Mysql doesn't support array arguments");
}
#[cfg(feature = "postgres-vector")]
Value::Vector(_) => {
panic!("Mysql doesn't support vector arguments");
}
#[cfg(feature = "with-ipnetwork")]
Value::IpNetwork(_) => {
panic!("Mysql doesn't support IpNetwork arguments");
Expand Down
4 changes: 4 additions & 0 deletions sea-query-binder/src/sqlx_postgres.rs
Original file line number Diff line number Diff line change
Expand Up @@ -314,6 +314,10 @@ impl sqlx::IntoArguments<'_, sqlx::postgres::Postgres> for SqlxValues {
let _ = args.add(value);
}
},
#[cfg(feature = "postgres-vector")]
Value::Vector(v) => {
let _ = args.add(v.as_deref());
}
}
}
args
Expand Down
4 changes: 4 additions & 0 deletions sea-query-binder/src/sqlx_sqlite.rs
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,10 @@ impl<'q> sqlx::IntoArguments<'q, sqlx::sqlite::Sqlite> for SqlxValues {
Value::Array(_, _) => {
panic!("Sqlite doesn't support array arguments");
}
#[cfg(feature = "postgres-vector")]
Value::Vector(_) => {
panic!("Sqlite doesn't support vector arguments");
}
}
}
args
Expand Down
2 changes: 2 additions & 0 deletions sea-query-diesel/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ uuid = { version = "1", default-features = false, optional = true }
serde_json = { version = "1", default-features = false, optional = true }
ipnetwork = { version = "0.20", default-features = false, optional = true }
mac_address = { version = "1.1", default-features = false, optional = true }
pgvector = { version = "<0.4", default-features = false, optional = true }

[features]
default = []
Expand Down Expand Up @@ -64,3 +65,4 @@ with-ipnetwork = [
]
with-mac_address = ["sea-query/with-mac_address", "mac_address"]
postgres-array = ["sea-query/postgres-array"]
postgres-vector = ["sea-query/postgres-vector", "pgvector/diesel"]
2 changes: 2 additions & 0 deletions sea-query-diesel/src/backend/mysql.rs
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,8 @@ impl TransformValue for Mysql {
Value::MacAddress(_) => bail!("Mysql doesn't support MacAddress arguments"),
#[cfg(feature = "postgres-array")]
Value::Array(_, _) => bail!("Mysql doesn't support array arguments"),
#[cfg(feature = "postgres-vector")]
Value::Vector(_) => bail!("Mysql doesn't support vector arguments"),
};
Ok(transformed)
}
Expand Down
2 changes: 2 additions & 0 deletions sea-query-diesel/src/backend/postgres.rs
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,8 @@ impl TransformValue for Pg {
)
}
},
#[cfg(feature = "postgres-vector")]
Value::Vector(v) => build!(pgvector::sql_types::Vector, v.map(|v| *v)),
};
Ok(transformed)
}
Expand Down
2 changes: 2 additions & 0 deletions sea-query-diesel/src/backend/sqlite.rs
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,8 @@ impl TransformValue for Sqlite {
Value::MacAddress(_) => bail!("Sqlite doesn't support MacAddress arguments"),
#[cfg(feature = "postgres-array")]
Value::Array(_, _) => bail!("Sqlite doesn't support array arguments"),
#[cfg(feature = "postgres-vector")]
Value::Vector(_) => bail!("Sqlite doesn't support vector arguments"),
};
Ok(transformed)
}
Expand Down
2 changes: 2 additions & 0 deletions sea-query-postgres/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ rust-version = "1.60"
[dependencies]
sea-query = { version = "0.31.0", path = "..", default-features = false }
postgres-types = { version = "0.2", default-features = false }
pgvector = { version = "<0.4", default-features = false, optional = true }
bytes = { version = "1", default-features = false }
rust_decimal = { version = "1", default-features = false, optional = true }
bigdecimal = { version = "0.3", default-features = false, optional = true }
Expand All @@ -35,5 +36,6 @@ with-bigdecimal = ["sea-query/with-bigdecimal", "bigdecimal"]
with-uuid = ["postgres-types/with-uuid-1", "sea-query/with-uuid"]
with-time = ["postgres-types/with-time-0_3", "sea-query/with-time"]
postgres-array = ["postgres-types/array-impls", "sea-query/postgres-array"]
postgres-vector = ["sea-query/postgres-vector", "pgvector/postgres"]
with-ipnetwork = ["postgres-types/with-cidr-0_2", "sea-query/with-ipnetwork", "ipnetwork", "cidr"]
with-mac_address = ["postgres-types/with-eui48-1", "sea-query/with-mac_address", "mac_address", "eui48"]
4 changes: 4 additions & 0 deletions sea-query-postgres/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,10 @@ impl ToSql for PostgresValue {
.to_sql(ty, out),
#[cfg(feature = "postgres-array")]
Value::Array(_, None) => Ok(IsNull::Yes),
#[cfg(feature = "postgres-vector")]
Value::Vector(Some(v)) => v.to_sql(ty, out),
#[cfg(feature = "postgres-vector")]
Value::Vector(None) => Ok(IsNull::Yes),
#[cfg(feature = "with-ipnetwork")]
Value::IpNetwork(v) => {
use cidr::IpCidr;
Expand Down
1 change: 1 addition & 0 deletions sea-query-rusqlite/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -30,3 +30,4 @@ with-time = ["rusqlite/time", "sea-query/with-time"]
with-ipnetwork = ["sea-query/with-ipnetwork"]
with-mac_address = ["sea-query/with-mac_address"]
postgres-array = ["sea-query/postgres-array"]
postgres-vector = ["sea-query/postgres-vector"]
4 changes: 4 additions & 0 deletions sea-query-rusqlite/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,10 @@ impl ToSql for RusqliteValue {
Value::Array(_, _) => {
panic!("Rusqlite doesn't support Array arguments");
}
#[cfg(feature = "postgres-vector")]
Value::Vector(_) => {
panic!("Rusqlite doesn't support Vector arguments");
}
}
}
}
1 change: 1 addition & 0 deletions src/backend/mysql/table.rs
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ impl TableBuilder for MysqlQueryBuilder {
.join("', '")
),
ColumnType::Array(_) => unimplemented!("Array is not available in MySQL."),
ColumnType::Vector(_) => unimplemented!("Vector is not available in MySQL."),
ColumnType::Cidr => unimplemented!("Cidr is not available in MySQL."),
ColumnType::Inet => unimplemented!("Inet is not available in MySQL."),
ColumnType::MacAddr => unimplemented!("MacAddr is not available in MySQL."),
Expand Down
6 changes: 6 additions & 0 deletions src/backend/postgres/query.rs
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,12 @@ impl QueryBuilder for PostgresQueryBuilder {
PgBinOper::CastJsonField => "->>",
PgBinOper::Regex => "~",
PgBinOper::RegexCaseInsensitive => "~*",
#[cfg(feature = "postgres-vector")]
PgBinOper::EuclideanDistance => "<->",
#[cfg(feature = "postgres-vector")]
PgBinOper::NegativeInnerProduct => "<#>",
#[cfg(feature = "postgres-vector")]
PgBinOper::CosineDistance => "<=>",
}
)
.unwrap(),
Expand Down
4 changes: 4 additions & 0 deletions src/backend/postgres/table.rs
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,10 @@ impl TableBuilder for PostgresQueryBuilder {
self.prepare_column_type(elem_type, &mut sql);
format!("{sql}[]")
}
ColumnType::Vector(size) => match size {
Some(size) => format!("vector({size})"),
None => "vector".into(),
},
ColumnType::Custom(iden) => iden.to_string(),
ColumnType::Enum { name, .. } => name.to_string(),
ColumnType::Cidr => "cidr".into(),
Expand Down
13 changes: 13 additions & 0 deletions src/backend/query_builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1045,6 +1045,8 @@ pub trait QueryBuilder:
Value::MacAddress(None) => write!(s, "NULL").unwrap(),
#[cfg(feature = "postgres-array")]
Value::Array(_, None) => write!(s, "NULL").unwrap(),
#[cfg(feature = "postgres-vector")]
Value::Vector(None) => write!(s, "NULL").unwrap(),
Value::Bool(Some(b)) => write!(s, "{}", if *b { "TRUE" } else { "FALSE" }).unwrap(),
Value::TinyInt(Some(v)) => write!(s, "{v}").unwrap(),
Value::SmallInt(Some(v)) => write!(s, "{v}").unwrap(),
Expand Down Expand Up @@ -1118,6 +1120,17 @@ pub trait QueryBuilder:
.join(",")
)
.unwrap(),
#[cfg(feature = "postgres-vector")]
Value::Vector(Some(v)) => {
write!(s, "'[").unwrap();
for (i, &element) in v.as_slice().iter().enumerate() {
if i != 0 {
write!(s, ",").unwrap();
}
write!(s, "{element}").unwrap();
}
write!(s, "]'").unwrap();
}
#[cfg(feature = "with-ipnetwork")]
Value::IpNetwork(Some(v)) => write!(s, "'{v}'").unwrap(),
#[cfg(feature = "with-mac_address")]
Expand Down
1 change: 1 addition & 0 deletions src/backend/sqlite/table.rs
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,7 @@ impl SqliteQueryBuilder {
ColumnType::Custom(iden) => iden.to_string(),
ColumnType::Enum { .. } => "enum_text".into(),
ColumnType::Array(_) => unimplemented!("Array is not available in Sqlite."),
ColumnType::Vector(_) => unimplemented!("Vector is not available in Sqlite."),
ColumnType::Cidr => unimplemented!("Cidr is not available in Sqlite."),
ColumnType::Inet => unimplemented!("Inet is not available in Sqlite."),
ColumnType::MacAddr => unimplemented!("MacAddr is not available in Sqlite."),
Expand Down
Loading

0 comments on commit b83eaa9

Please sign in to comment.