Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Casts #19

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 78 additions & 0 deletions sql/lanterndb.sql
Original file line number Diff line number Diff line change
Expand Up @@ -50,3 +50,81 @@ BEGIN
END;
$$
LANGUAGE plpgsql;
-- -- CREATE TYPEs vec*

-- Function that are generic over the family of vec types

CREATE FUNCTION ldb_generic_vec_typmod_in(cstring[]) RETURNS integer
AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;

-- 8-byte unit-vector (i.e. vector with elements in range [-1, 1])
CREATE TYPE uvec8;

CREATE FUNCTION ldb_uvec8_in(cstring, oid, integer) RETURNS uvec8 AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
CREATE FUNCTION ldb_uvec8_out(uvec8) RETURNS cstring AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
CREATE FUNCTION ldb_uvec8_recv(internal, oid, integer) RETURNS uvec8 AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
CREATE FUNCTION ldb_uvec8_send(uvec8) RETURNS bytea AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;

CREATE TYPE uvec8 (
INPUT = ldb_uvec8_in,
OUTPUT = ldb_uvec8_out,
RECEIVE = ldb_uvec8_recv,
SEND = ldb_uvec8_send,
TYPMOD_IN = ldb_generic_vec_typmod_in,
STORAGE = extended
);

CREATE TYPE vec8;

CREATE FUNCTION ldb_vec8_in(cstring, oid, integer) RETURNS vec8 AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
CREATE FUNCTION ldb_vec8_out(vec8) RETURNS cstring AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
CREATE FUNCTION ldb_vec8_recv(internal, oid, integer) RETURNS vec8 AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
CREATE FUNCTION ldb_vec8_send(vec8) RETURNS bytea AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;

CREATE TYPE vec8 (
INPUT = ldb_vec8_in,
OUTPUT = ldb_vec8_out,
RECEIVE = ldb_vec8_recv,
SEND = ldb_vec8_send,
TYPMOD_IN = ldb_generic_vec_typmod_in,
STORAGE = extended
);

-- cast functions

CREATE FUNCTION ldb_cast_uvec8_uvec8(uvec8, integer, boolean) RETURNS uvec8
AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
CREATE FUNCTION ldb_cast_array_uvec8(integer[], integer, boolean) RETURNS uvec8
AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
CREATE FUNCTION ldb_cast_array_uvec8(real[], integer, boolean) RETURNS uvec8
AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
CREATE FUNCTION ldb_cast_array_uvec8(double precision[], integer, boolean) RETURNS uvec8
AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
CREATE FUNCTION ldb_cast_array_uvec8(numeric[], integer, boolean) RETURNS uvec8
AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;

CREATE FUNCTION ldb_cast_vec_real(uvec8, integer, boolean) RETURNS real[]
AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;

-- casts

CREATE CAST (uvec8 AS uvec8)
WITH FUNCTION ldb_cast_uvec8_uvec8(uvec8, integer, boolean) AS IMPLICIT;

CREATE CAST (integer[] AS uvec8)
WITH FUNCTION ldb_cast_array_uvec8(integer[], integer, boolean) AS ASSIGNMENT;

CREATE CAST (real[] AS uvec8)
WITH FUNCTION ldb_cast_array_uvec8(real[], integer, boolean) AS ASSIGNMENT;

CREATE CAST (double precision[] AS uvec8)
WITH FUNCTION ldb_cast_array_uvec8(double precision[], integer, boolean) AS ASSIGNMENT;

CREATE CAST (numeric[] AS uvec8)
WITH FUNCTION ldb_cast_array_uvec8(numeric[], integer, boolean) AS ASSIGNMENT;

CREATE CAST (uvec8 AS real[])
WITH FUNCTION ldb_cast_vec_real(uvec8, integer, boolean) AS ASSIGNMENT;



2 changes: 1 addition & 1 deletion src/hnsw/build.c
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,9 @@
#include "bench.h"
#include "external_index.h"
#include "hnsw.h"
#include "pgvector_vector.h"
#include "usearch.h"
#include "utils.h"
#include "vector.h"

#if PG_VERSION_NUM >= 140000
#include "utils/backend_progress.h"
Expand Down
2 changes: 1 addition & 1 deletion src/hnsw/insert.c
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,9 @@
#include "external_index.h"
#include "hnsw.h"
#include "options.h"
#include "pgvector_vector.h"
#include "usearch.h"
#include "utils.h"
#include "vector.h"

/*
* Context delete callback for insert context
Expand Down
File renamed without changes.
2 changes: 1 addition & 1 deletion src/hnsw/scan.c
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
#include "external_index.h"
#include "hnsw.h"
#include "options.h"
#include "vector.h"
#include "pgvector_vector.h"

PG_MODULE_MAGIC;

Expand Down
55 changes: 55 additions & 0 deletions src/vec_type.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
#ifndef LDB_VEC_H
#define LDB_VEC_H
#include <postgres.h>

typedef struct
{
int32 vl_len_; /* varlena header (do not touch directly!) */
uint16 dim; /* number of dimensions */
uint16 elem_type;
char data[ FLEXIBLE_ARRAY_MEMBER ];
} LDBVec;

static inline int VecScalarSize(usearch_scalar_kind_t s)
{
switch(s) {
// clang-format off
case usearch_scalar_f64_k: return 8;
case usearch_scalar_f32_k: return 4;
case usearch_scalar_f16_k: return 2;
case usearch_scalar_f8_k: return 1;
case usearch_scalar_b1_k: return 1;
// clang-format on
}
assert(false);
}

static inline LDBVec *NewLDBVec(int dim, int elem_type)
{
LDBVec *result;
int size;

size = sizeof(LDBVec) + dim * VecScalarSize(elem_type);
result = (LDBVec *)palloc0(size);
SET_VARSIZE(result, size);
result->dim = dim;
result->elem_type = elem_type;

return result;
}

/* Confined by uint16 in LDBVec structure */
#define LDB_VEC_MAX_DIM ((1 << 16) - 1)
/*
* Returns a pointer to the actual array data.
*/
#define LDBVEC_DATA_SIZE(a) (((a->dim)) * (VecScalarSize(a->elem_type)))

/*
* Returns a pointer to the actual array data.
*/
#define LDBVEC_DATA_PTR(a) (((void *)(a->data)))

#define DatumGetLDBVec(x) ((LDBVec *)PG_DETOAST_DATUM(x))

#endif // LDB_VEC_H
Loading
Loading