Skip to content

Commit

Permalink
Merge pull request #38 from TopKeyboard/add-parameter
Browse files Browse the repository at this point in the history
PyBinding: Add parameters for query
  • Loading branch information
eric-epsilla authored Aug 18, 2023
2 parents 1452fe2 + e02c16c commit a1165c1
Show file tree
Hide file tree
Showing 5 changed files with 48 additions and 17 deletions.
45 changes: 34 additions & 11 deletions engine/bindings/python/interface.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,6 @@ static PyObject *create_table(PyObject *self, PyObject *args, PyObject *kwargs)

// Iterate through the list and extract dictionaries
Py_ssize_t list_size = PyList_Size(tableFieldsListPtr);

for (Py_ssize_t i = 0; i < list_size; ++i) {
PyObject *dict_obj = PyList_GetItem(tableFieldsListPtr, i);
vectordb::engine::meta::FieldSchema field;
Expand All @@ -108,7 +107,6 @@ static PyObject *create_table(PyObject *self, PyObject *args, PyObject *kwargs)
PyErr_SetString(PyExc_TypeError, "invalid content: ID is not valid UTF8 string");
return NULL;
}

Py_DECREF(nameKey);
Py_DECREF(nameValue);

Expand Down Expand Up @@ -144,15 +142,14 @@ static PyObject *create_table(PyObject *self, PyObject *args, PyObject *kwargs)
}
Py_DECREF(dimensionsKey);
Py_DECREF(dimensionsValue);
schema.fields_.push_back(field);
}
schema.fields_.push_back(field);
Py_DECREF(dict_obj);
}

Py_DECREF(tableFieldsListPtr);

// TODO: add auto embedding here

auto status = db->CreateTable(db_name, schema);
if (!status.ok()) {
PyErr_SetString(PyExc_Exception, status.message().c_str());
Expand Down Expand Up @@ -216,15 +213,33 @@ static PyObject *insert(PyObject *self, PyObject *args, PyObject *kwargs) {
}

static PyObject *query(PyObject *self, PyObject *args, PyObject *kwargs) {
static const char *keywords[] = {"table_name", "query_field", "query_vector", "limit", NULL};
static const char *keywords[] = {
"table_name",
"query_field",
"query_vector",
"response_fields",
"limit",
"with_distance",
NULL};
const char *tableNamePtr, *queryFieldPtr;
int limit;
PyObject *queryVector;

if (!PyArg_ParseTupleAndKeywords(args, kwargs, "ssOi", (char **)keywords, &tableNamePtr, &queryFieldPtr, &queryVector, &limit)) {
int limit, withDistance;
PyObject *queryVector, *responseFields;

if (!PyArg_ParseTupleAndKeywords(
args,
kwargs,
"ssOOip",
(char **)keywords,
&tableNamePtr,
&queryFieldPtr,
&queryVector,
&responseFields,
&limit,
&withDistance)) {
return NULL;
}
Py_XINCREF(queryVector);
Py_XINCREF(responseFields);
auto queryFields = std::vector<std::string>();

Py_ssize_t queryVectorSize = PyList_Size(queryVector);
Expand All @@ -236,6 +251,14 @@ static PyObject *query(PyObject *self, PyObject *args, PyObject *kwargs) {
vectorArr[i] = PyFloat_AsDouble(elem);
Py_XDECREF(elem);
}
Py_ssize_t responseFieldsSize = PyList_Size(responseFields);
for (Py_ssize_t i = 0; i < responseFieldsSize; ++i) {
PyObject *elem = PyObject_Str(PyList_GetItem(responseFields, i));
Py_XINCREF(elem);
std::string field = PyUnicode_AsUTF8(elem);
Py_XDECREF(elem);
queryFields.push_back(field);
}
Py_XDECREF(queryVector);
auto result = vectordb::Json();
auto status = db->Search(
Expand All @@ -247,8 +270,8 @@ static PyObject *query(PyObject *self, PyObject *args, PyObject *kwargs) {
vectorArr.get(),
limit,
result,
// TODO: make it variable
true);
withDistance);

if (!status.ok()) {
PyErr_SetString(PyExc_Exception, status.message().c_str());
return NULL;
Expand Down
5 changes: 2 additions & 3 deletions engine/db/catalog/basic_meta_impl.cpp
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
// basic_meta_impl.cpp
#include "db/catalog/basic_meta_impl.hpp"

#include <iostream>

#include "utils/common_util.hpp"
#include "utils/json.hpp"

#include <iostream>

namespace vectordb {
namespace engine {
namespace meta {
Expand Down Expand Up @@ -150,7 +150,6 @@ Status SaveDBToFile(const DatabaseSchema& db, const std::string& file_path) {

// Write the Json object to a string
std::string json_string = json.DumpToString();

// Write the string to the file
return server::CommonUtil::AtomicWriteToFile(file_path, json_string);
}
Expand Down
7 changes: 6 additions & 1 deletion engine/test.sh
Original file line number Diff line number Diff line change
@@ -1,4 +1,9 @@
export PYTHONPATH=./build/
export DB_PATH=/tmp/db2
rm -rf "$DB_PATH"
python3 test/bindings/python/concurrent_test.py
echo $1
if [ "$1" == "--single-thread" ]; then
python3 test/bindings/python/test.py
else
python3 test/bindings/python/concurrent_test.py
fi
4 changes: 3 additions & 1 deletion engine/test/bindings/python/concurrent_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,10 @@ def run_task():
(code, response) = epsilla.query(
table_name="MyTable",
query_field="Embedding",
response_fields=["Doc"],
query_vector=[0.35, 0.55, 0.47, 0.94],
limit=2
limit=2,
with_distance=True
)
print(code, response)

Expand Down
4 changes: 3 additions & 1 deletion engine/test/bindings/python/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,10 @@
(code, response) = epsilla.query(
table_name="MyTable",
query_field="Embedding",
response_fields=["ID", "Doc", "Embedding"],
query_vector=[0.35, 0.55, 0.47, 0.94],
limit=2
limit=2,
with_distance=True
)
print(code, response)

Expand Down

0 comments on commit a1165c1

Please sign in to comment.