diff --git a/README.md b/README.md index faf3ea7..48b1d37 100644 --- a/README.md +++ b/README.md @@ -61,8 +61,8 @@ documents = [ upload.documents( database="ducksearch.duckdb", - key="id", # unique document identifier - fields=["title", "style", "date", "popularity"], # list of fields to index + key="id", # Unique document identifier + fields=["title", "style"], # List of fields to use for search. documents=documents, dtypes={ "date": "DATE", @@ -73,7 +73,7 @@ upload.documents( ## Search -`search.documents` returns a list of list of documents ordered by relevance. We can control the number of documents to return using the `top_k` parameter. The following example demonstrates how to search for documents with the queries "punk" and "california" while filtering the results to include only documents with a date after 1970 and a popularity score greater than 8. +`search.documents` returns a list of list of documents ordered by relevance. We can control the number of documents to return using the `top_k` parameter. The following example demonstrates how to search for documents with the queries "punk" and "california" while filtering the results to include only documents with a date after 1970 and a popularity score greater than 8. We will order the results by a weighted sum of the BM25 score and the popularity score provided in the document. ```python from ducksearch import search @@ -83,6 +83,7 @@ search.documents( queries=["punk", "california"], top_k=10, filters="YEAR(date) >= 1970 AND popularity > 8", + order_by="0.8 * score + 0.2 * popularity DESC", ) ``` @@ -113,6 +114,8 @@ search.documents( Filters are SQL expressions that are applied to the search results. We can use every filtering function DuckDB provides such as [date functions](https://duckdb.org/docs/sql/functions/date). +Both `filters` and `order_by` parameters are optional. If not provided, the results are ordered by BM25 relevance and no filters are applied. + ## Delete and update index We can delete documents and update the BM25 weights accordingly using the `delete.documents` function. @@ -132,8 +135,7 @@ To update the index, we should first delete the documents and then upload the up ### HuggingFace -The `upload.documents` function can also index HuggingFace datasets directly from the url. -The following example demonstrates how to index the FineWeb dataset from HuggingFace: +The `upload.documents` function can also index HuggingFace datasets directly from the url. The following example demonstrates how to index the FineWeb dataset from HuggingFace. We will use the fields "text" and "url" for search. We will also specify the data types for the "date", "token_count", and "language_score" fields to be able to filter the results. ```python from ducksearch import upload @@ -141,53 +143,68 @@ from ducksearch import upload upload.documents( database="fineweb.duckdb", key="id", - fields=["text", "url", "date", "language", "token_count", "language_score"], + fields=["text", "url"], documents="https://huggingface.co/datasets/HuggingFaceFW/fineweb/resolve/main/sample/10BT/000_00000.parquet", dtypes={ "date": "DATE", "token_count": "INT", "language_score": "FLOAT", }, - limit=1000, # demonstrate with a small dataset + limit=3000, # demonstrate with a small dataset ) ``` -We can then search the FineWeb dataset with the `search.documents` function: +We can then search the FineWeb dataset with the `search.documents` function. We order the results by BM25 score and then date. 
```python from ducksearch import search search.documents( database="fineweb.duckdb", - queries="earth science", + queries=["earth science"], top_k=2, + order_by="score DESC, date DESC", ) ``` ```python [ - { - "id": "", - "text": "Earth Science Tutors in Rowland ...", - "date": Timestamp("2017-08-19 00:00:00"), - "language": "en", - "token_count": 313, - "language_score": 0.8718525171279907, - "score": 1.1588547229766846, - }, - { - "score": 1.6727683544158936, - "id": "", - "text": "The existing atmosphere surrounding the earth contains ...", - "url": "http://www.accuracyingenesis.com/atmargon.html", - "date": Timestamp("2015-04-02 00:00:00"), - "language": "en", - "token_count": 1348, - "language_score": 0.9564403295516968, - }, + [ + { + "id": "", + "text": "Earth Science Tutors in Rowland...", + "id_1": "", + "dump": "CC-MAIN-2017-34", + "url": "http://rowland.universitytutor.com/rowland_earth-science-tutoring", + "date": Timestamp("2017-08-19 00:00:00"), + "file_path": "s3://commoncrawl/crawl-data/CC-MAIN-2017-34/segments/1502886105304.35/warc/CC-MAIN-20170819051034-20170819071034-00240.warc.gz", + "language": "en", + "language_score": 0.8718525171279907, + "token_count": 313, + "bm25id": 523, + "score": 2.3761106729507446, + }, + { + "id": "", + "text": "- Geomagnetic field....", + "id_1": "", + "dump": "CC-MAIN-2022-21", + "url": "https://www.imperial.ac.uk/people/adrian.muxworthy/?respub-action=citation.html&id=1149861&noscript=noscript", + "date": Timestamp("2022-05-20 00:00:00"), + "file_path": "s3://commoncrawl/crawl-data/CC-MAIN-2022-21/segments/1652662530553.34/warc/CC-MAIN-20220519235259-20220520025259-00601.warc.gz", + "language": "en", + "language_score": 0.8225595951080322, + "token_count": 517, + "bm25id": 4783, + "score": 2.3569871187210083, + }, + ] ] + ``` +Note: by default, results are ordered by BM25 relevance. + ## Tables Ducksearch creates two distinct schemas: `bm25_tables`, `bm25_documents`. diff --git a/docs/api/decorators/connect-to-duckdb.md b/docs/api/decorators/connect-to-duckdb.md index 1773bf2..1aaa345 100644 --- a/docs/api/decorators/connect-to-duckdb.md +++ b/docs/api/decorators/connect-to-duckdb.md @@ -1,6 +1,6 @@ # connect_to_duckdb -Establish a connection to the DuckDB database. +Establish a connection to the DuckDB database. Retry connecting if an error occurs. @@ -18,6 +18,16 @@ Establish a connection to the DuckDB database. Optional configuration settings for the DuckDB connection. +- **max_retry** (*int*) – defaults to `20` + + The maximum number of times to retry connecting to DuckDB. + +- **sleep_time** (*float*) – defaults to `0.1` + + The time to sleep between retries. + +- **kwargs** + diff --git a/docs/api/hf/insert-documents.md b/docs/api/hf/insert-documents.md index 7c305f2..f0fb146 100644 --- a/docs/api/hf/insert-documents.md +++ b/docs/api/hf/insert-documents.md @@ -18,10 +18,6 @@ Insert documents from a Hugging Face dataset into DuckDB. The key field that uniquely identifies each document (e.g., 'query_id'). -- **fields** (*str | list[str]*) - - A list of fields to be inserted from the dataset. If a single field is provided as a string, it will be converted to a list. - - **url** (*str*) The URL of the Hugging Face dataset in Parquet format. @@ -32,6 +28,8 @@ Insert documents from a Hugging Face dataset into DuckDB. - **limit** (*int | None*) – defaults to `None` +- **dtypes** (*dict | None*) – defaults to `None` + ## Examples @@ -41,13 +39,24 @@ Insert documents from a Hugging Face dataset into DuckDB. >>> upload.documents( ... 
database="test.duckdb", -... documents="hf://datasets/lightonai/lighton-ms-marco-mini/train.parquet", -... fields=["document_ids", "scores"], +... documents="hf://datasets/lightonai/lighton-ms-marco-mini/queries.parquet", ... key="query_id", +... fields=["query_id", "text"], ... ) | Table | Size | |----------------|------| | documents | 19 | | bm25_documents | 19 | + +>>> upload.documents( +... database="test.duckdb", +... documents="hf://datasets/lightonai/lighton-ms-marco-mini/documents.parquet", +... key="document_id", +... fields=["document_id", "text"], +... ) +| Table | Size | +|----------------|------| +| documents | 51 | +| bm25_documents | 51 | ``` diff --git a/docs/api/overview.md b/docs/api/overview.md index f2ada58..447e79d 100644 --- a/docs/api/overview.md +++ b/docs/api/overview.md @@ -25,6 +25,7 @@ ## tables +- [add_columns_documents](../tables/add-columns-documents) - [create_documents](../tables/create-documents) - [create_documents_queries](../tables/create-documents-queries) - [create_queries](../tables/create-queries) @@ -43,6 +44,15 @@ ## utils + +**Classes** + +- [ParallelTqdm](../utils/ParallelTqdm) + +**Functions** + - [batchify](../utils/batchify) +- [generate_random_hash](../utils/generate-random-hash) +- [get_list_columns_df](../utils/get-list-columns-df) - [plot](../utils/plot) diff --git a/docs/api/search/documents.md b/docs/api/search/documents.md index 50fb956..5664aed 100644 --- a/docs/api/search/documents.md +++ b/docs/api/search/documents.md @@ -38,6 +38,12 @@ Search for documents in the documents table using specified queries. Optional SQL filters to apply during the search. +- **order_by** (*str | None*) – defaults to `None` + +- **tqdm_bar** (*bool*) – defaults to `True` + + Whether to display a progress bar when searching. + ## Examples diff --git a/docs/api/search/graphs.md b/docs/api/search/graphs.md index bb2339a..5567fc6 100644 --- a/docs/api/search/graphs.md +++ b/docs/api/search/graphs.md @@ -38,6 +38,8 @@ Search for graphs in DuckDB using the provided queries. Optional SQL filters to apply during the search. +- **tqdm_bar** (*bool*) – defaults to `True` + ## Examples @@ -53,17 +55,22 @@ Search for graphs in DuckDB using the provided queries. ... fields=["title", "text"], ... documents=documents, ... ) +| Table | Size | +|----------------|------| +| documents | 5183 | +| bm25_documents | 5183 | >>> upload.queries( ... database="test.duckdb", ... queries=queries, ... documents_queries=qrels, ... ) - ->>> scores = search.graphs( -... database="test.duckdb", -... queries=queries, -... top_k=10, -... ) +| Table | Size | +|-------------------|------| +| documents | 5183 | +| queries | 807 | +| bm25_documents | 5183 | +| bm25_queries | 807 | +| documents_queries | 916 | ``` diff --git a/docs/api/search/queries.md b/docs/api/search/queries.md index bd2c8fe..9f71712 100644 --- a/docs/api/search/queries.md +++ b/docs/api/search/queries.md @@ -38,6 +38,8 @@ Search for queries in the queries table using specified queries. Optional SQL filters to apply during the search. +- **tqdm_bar** (*bool*) – defaults to `True` + ## Examples diff --git a/docs/api/search/search.md b/docs/api/search/search.md index 31514f1..fde5e7a 100644 --- a/docs/api/search/search.md +++ b/docs/api/search/search.md @@ -50,6 +50,12 @@ Run the search for documents or queries in parallel. Optional SQL filters to apply during the search. +- **order_by** (*str | None*) – defaults to `None` + +- **tqdm_bar** (*bool*) – defaults to `True` + + Whether to display a progress bar when searching. 
+ ## Examples @@ -67,7 +73,6 @@ Run the search for documents or queries in parallel. ... top_k=10, ... ) ->>> assert len(documents) == 1 ->>> assert len(documents[0]) == 10 +>>> assert len(documents) == 10 ``` diff --git a/docs/api/search/update-index-documents.md b/docs/api/search/update-index-documents.md index 35800b2..99fa334 100644 --- a/docs/api/search/update-index-documents.md +++ b/docs/api/search/update-index-documents.md @@ -10,6 +10,10 @@ Update the BM25 search index for documents. The name of the DuckDB database. +- **fields** (*list[str]*) + + The fields to index for each document. + - **k1** (*float*) – defaults to `1.5` The BM25 k1 parameter, controls term saturation. diff --git a/docs/api/tables/add-columns-documents.md b/docs/api/tables/add-columns-documents.md new file mode 100644 index 0000000..55d900c --- /dev/null +++ b/docs/api/tables/add-columns-documents.md @@ -0,0 +1,21 @@ +# add_columns_documents + +Add columns to the documents table in the DuckDB database. + + + +## Parameters + +- **database** (*str*) + +- **schema** (*str*) + +- **columns** (*list[str] | str*) + +- **dtypes** (*dict*) – defaults to `None` + +- **config** (*dict*) – defaults to `None` + + + + diff --git a/docs/api/tables/create-documents.md b/docs/api/tables/create-documents.md index 1a1201a..ca152af 100644 --- a/docs/api/tables/create-documents.md +++ b/docs/api/tables/create-documents.md @@ -10,7 +10,7 @@ Create the documents table in the DuckDB database. - **schema** (*str*) -- **fields** (*str | list[str]*) +- **columns** (*str | list[str]*) - **dtypes** (*dict[str, str] | None*) – defaults to `None` @@ -31,7 +31,7 @@ Create the documents table in the DuckDB database. >>> tables.create_documents( ... database="test.duckdb", ... schema="bm25_tables", -... fields=["title", "text"], +... columns=["title", "text"], ... dtypes={"text": "VARCHAR", "title": "VARCHAR"}, ... ) @@ -46,7 +46,7 @@ Create the documents table in the DuckDB database. ... schema="bm25_tables", ... key="id", ... df=df, -... fields=["title", "text"], +... columns=["title", "text"], ... ) ``` diff --git a/docs/api/tables/insert-documents.md b/docs/api/tables/insert-documents.md index 5cd64eb..dad119c 100644 --- a/docs/api/tables/insert-documents.md +++ b/docs/api/tables/insert-documents.md @@ -22,7 +22,7 @@ Insert documents into the documents table with optional multi-threading. The field that uniquely identifies each document (e.g., 'id'). -- **fields** (*list[str] | str*) +- **columns** (*list[str] | str*) The list of document fields to insert. Can be a string if inserting a single field. @@ -61,7 +61,7 @@ Insert documents into the documents table with optional multi-threading. ... database="test.duckdb", ... schema="bm25_tables", ... key="id", -... fields=["title", "text"], +... columns=["title", "text"], ... df=df ... ) ``` diff --git a/docs/api/tables/select-documents.md b/docs/api/tables/select-documents.md index 0a5bffd..97ee634 100644 --- a/docs/api/tables/select-documents.md +++ b/docs/api/tables/select-documents.md @@ -4,6 +4,22 @@ Select all documents from the documents table. +## Parameters + +- **database** (*str*) + + The name of the DuckDB database. + +- **schema** (*str*) + + The schema where the documents table is located. + +- **limit** (*int | None*) – defaults to `None` + +- **config** (*dict | None*) – defaults to `None` + + Optional configuration options for the DuckDB connection. 
+ ## Examples diff --git a/docs/api/upload/documents.md b/docs/api/upload/documents.md index 60691a6..87bb326 100644 --- a/docs/api/upload/documents.md +++ b/docs/api/upload/documents.md @@ -64,6 +64,10 @@ Upload documents to DuckDB, create necessary schema, and index using BM25. - **limit** (*int | None*) – defaults to `None` +- **tqdm_bar** (*bool*) – defaults to `True` + + Whether to display a progress bar when uploading documents. + diff --git a/docs/api/utils/ParallelTqdm.md b/docs/api/utils/ParallelTqdm.md new file mode 100644 index 0000000..645b327 --- /dev/null +++ b/docs/api/utils/ParallelTqdm.md @@ -0,0 +1,80 @@ +# ParallelTqdm + +joblib.Parallel, but with a tqdm progressbar. + + + +## Parameters + +- **total** (*int*) + + The total number of tasks to complete. + +- **desc** (*str*) + + A description of the task. + +- **tqdm_bar** (*bool*) – defaults to `True` + + Whether to display a tqdm progress bar. Default is True. + +- **show_joblib_header** (*bool*) – defaults to `False` + + Whether to display the joblib header. Default is False. + +- **kwargs** + + + + +## Methods + +???- note "__call__" + + Main function to dispatch parallel tasks. + + **Parameters** + + - **iterable** + +???- note "debug" + +???- note "dispatch_next" + + Dispatch more data for parallel processing + + This method is meant to be called concurrently by the multiprocessing callback. We rely on the thread-safety of dispatch_one_batch to protect against concurrent consumption of the unprotected iterator. + + +???- note "dispatch_one_batch" + + Prefetch the tasks for the next batch and dispatch them. + + The effective size of the batch is computed here. If there are no more jobs to dispatch, return False, else return True. The iterator consumption and dispatching is protected by the same lock so calling this function should be thread safe. + + **Parameters** + + - **iterator** + +???- note "format" + + Return the formatted representation of the object. + + **Parameters** + + - **obj** + - **indent** – defaults to `0` + +???- note "info" + +???- note "print_progress" + + Display the progress of the parallel execution using tqdm + + +???- note "warn" + +## References + +https://github.com/joblib/joblib/issues/972 + diff --git a/docs/api/utils/generate-random-hash.md b/docs/api/utils/generate-random-hash.md new file mode 100644 index 0000000..aae3cbb --- /dev/null +++ b/docs/api/utils/generate-random-hash.md @@ -0,0 +1,9 @@ +# generate_random_hash + +Generate a random SHA-256 hash. + + + + + + diff --git a/docs/api/utils/get-list-columns-df.md b/docs/api/utils/get-list-columns-df.md new file mode 100644 index 0000000..6599a20 --- /dev/null +++ b/docs/api/utils/get-list-columns-df.md @@ -0,0 +1,13 @@ +# get_list_columns_df + +Get a list of columns from a list of dictionaries or a DataFrame. + + + +## Parameters + +- **documents** (*list[dict] | pandas.core.frame.DataFrame*) + + + + diff --git a/docs/index.md b/docs/index.md index 7636083..3813b65 100644 --- a/docs/index.md +++ b/docs/index.md @@ -13,9 +13,7 @@

-DuckSearch is a lightweight and easy-to-use library to search documents. DuckSearch is built on top of DuckDB, a high-performance analytical database. DuckDB is designed to execute analytical SQL queries fast, and DuckSearch leverages this to provide efficient search and filtering features. DuckSearch index can be updated with new documents and documents can be deleted as well. - -DuckSearch also supports HuggingFace datasets, allowing to index datasets directly from the HuggingFace Hub. +DuckSearch is a lightweight and easy-to-use library to search documents. DuckSearch is built on top of DuckDB, a high-performance analytical database. DuckDB is designed to execute analytical SQL queries fast, and DuckSearch leverages this to provide efficient search and filtering features. The DuckSearch index can be updated with new documents, and documents can be deleted as well. DuckSearch also supports HuggingFace datasets, allowing you to index datasets directly from the HuggingFace Hub.

## Installation @@ -63,8 +61,8 @@ documents = [ upload.documents( database="ducksearch.duckdb", - key="id", # unique document identifier - fields=["title", "style", "date", "popularity"], # list of fields to index + key="id", # Unique document identifier + fields=["title", "style"], # List of fields to use for search. documents=documents, dtypes={ "date": "DATE", @@ -75,7 +73,7 @@ upload.documents( ## Search -`search.documents` returns a list of list of documents ordered by relevance. We can control the number of documents to return using the `top_k` parameter. The following example demonstrates how to search for documents with the queries "punk" and "california" while filtering the results to include only documents with a date after 1970 and a popularity score greater than 8. +`search.documents` returns a list of list of documents ordered by relevance. We can control the number of documents to return using the `top_k` parameter. The following example demonstrates how to search for documents with the queries "punk" and "california" while filtering the results to include only documents with a date after 1970 and a popularity score greater than 8. We will order the results by a weighted sum of the BM25 score and the popularity score provided in the document. ```python from ducksearch import search @@ -85,6 +83,7 @@ search.documents( queries=["punk", "california"], top_k=10, filters="YEAR(date) >= 1970 AND popularity > 8", + order_by="0.8 * score + 0.2 * popularity DESC", ) ``` @@ -115,6 +114,8 @@ search.documents( Filters are SQL expressions that are applied to the search results. We can use every filtering function DuckDB provides such as [date functions](https://duckdb.org/docs/sql/functions/date). +Both `filters` and `order_by` parameters are optional. If not provided, the results are ordered by BM25 relevance and no filters are applied. + ## Delete and update index We can delete documents and update the BM25 weights accordingly using the `delete.documents` function. @@ -134,8 +135,7 @@ To update the index, we should first delete the documents and then upload the up ### HuggingFace -The `upload.documents` function can also index HuggingFace datasets directly from the url. -The following example demonstrates how to index the FineWeb dataset from HuggingFace: +The `upload.documents` function can also index HuggingFace datasets directly from the url. The following example demonstrates how to index the FineWeb dataset from HuggingFace. We will use the fields "text" and "url" for search. We will also specify the data types for the "date", "token_count", and "language_score" fields to be able to filter the results. ```python from ducksearch import upload @@ -143,53 +143,78 @@ from ducksearch import upload upload.documents( database="fineweb.duckdb", key="id", - fields=["text", "url", "date", "language", "token_count", "language_score"], + fields=["text", "url"], documents="https://huggingface.co/datasets/HuggingFaceFW/fineweb/resolve/main/sample/10BT/000_00000.parquet", dtypes={ "date": "DATE", "token_count": "INT", "language_score": "FLOAT", }, - limit=1000, # demonstrate with a small dataset + limit=3000, # demonstrate with a small dataset ) ``` -We can then search the FineWeb dataset with the `search.documents` function: +We can then search the FineWeb dataset with the `search.documents` function. We order the results by BM25 score and then date. 
```python from ducksearch import search search.documents( database="fineweb.duckdb", - queries="earth science", + queries=["earth science"], top_k=2, + order_by="score DESC, date DESC", ) ``` ```python [ - { - "id": "", - "text": "Earth Science Tutors in Rowland ...", - "date": Timestamp("2017-08-19 00:00:00"), - "language": "en", - "token_count": 313, - "language_score": 0.8718525171279907, - "score": 1.1588547229766846, - }, - { - "score": 1.6727683544158936, - "id": "", - "text": "The existing atmosphere surrounding the earth contains ...", - "url": "http://www.accuracyingenesis.com/atmargon.html", - "date": Timestamp("2015-04-02 00:00:00"), - "language": "en", - "token_count": 1348, - "language_score": 0.9564403295516968, - }, + [ + { + "id": "", + "text": "Earth Science Tutors in Rowland...", + "id_1": "", + "dump": "CC-MAIN-2017-34", + "url": "http://rowland.universitytutor.com/rowland_earth-science-tutoring", + "date": Timestamp("2017-08-19 00:00:00"), + "file_path": "s3://commoncrawl/crawl-data/CC-MAIN-2017-34/segments/1502886105304.35/warc/CC-MAIN-20170819051034-20170819071034-00240.warc.gz", + "language": "en", + "language_score": 0.8718525171279907, + "token_count": 313, + "bm25id": 523, + "score": 2.3761106729507446, + }, + { + "id": "", + "text": "- Geomagnetic field....", + "id_1": "", + "dump": "CC-MAIN-2022-21", + "url": "https://www.imperial.ac.uk/people/adrian.muxworthy/?respub-action=citation.html&id=1149861&noscript=noscript", + "date": Timestamp("2022-05-20 00:00:00"), + "file_path": "s3://commoncrawl/crawl-data/CC-MAIN-2022-21/segments/1652662530553.34/warc/CC-MAIN-20220519235259-20220520025259-00601.warc.gz", + "language": "en", + "language_score": 0.8225595951080322, + "token_count": 517, + "bm25id": 4783, + "score": 2.3569871187210083, + }, + ] ] + ``` +Note: by default, results are ordered by BM25 relevance. + +## Tables + +DuckSearch creates two distinct schemas: `bm25_tables`, `bm25_documents`. + +- We can find the uploaded documents in the `bm25_tables.documents` table. + +- We can find the inverted index in the `bm25_documents.scores` table. You can update the scores as you wish. Just note that token scores will be updated each time you upload documents (every token score mentioned in the set of uploaded documents). + +- We can update the set of stopwords in the `bm25_documents.stopwords` table. + ## Benchmark @@ -210,6 +235,12 @@ search.documents( | trec-covid | 0.9533 | 1.0 | 9.4800 | 1.0 | 0.0074 | 0.0077 | 112.38 | 22.15 | 50 queries, 171K documents | | webis-touche2020 | 0.4130 | 0.5510 | 3.7347 | 0.7114 | 0.0564 | 0.0827 | 104.65 | 44.14 | 49 queries, 382K documents | +## References + +- [DuckDB](https://duckdb.org/) + +- [DuckDB Full Text Search](https://duckdb.org/docs/extensions/full_text_search.html): Note that DuckSearch relies partially on the DuckDB Full Text Search extension but accelerates the search process via `top_k_token` approximation, pre-computation of scores and multi-threading. + ## License DuckSearch is released under the MIT license. 
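As a complement to the Tables section above, here is a minimal, hedged sketch of how the underlying tables could be inspected with plain DuckDB. It assumes the quick-start database name `ducksearch.duckdb` and the schema layout described in that section (`bm25_tables.documents`, `bm25_documents.scores`, `bm25_documents.stopwords`); column names other than `term` are not guaranteed, hence the `SELECT *` queries.

```python
import duckdb

# Open the index created by upload.documents; read_only avoids locking it.
conn = duckdb.connect("ducksearch.duckdb", read_only=True)

# Uploaded documents live in bm25_tables.documents.
print(conn.sql("SELECT * FROM bm25_tables.documents LIMIT 5"))

# The pre-computed inverted index (one row per token) lives in bm25_documents.scores.
print(conn.sql("SELECT term FROM bm25_documents.scores LIMIT 5"))

# The stopword list can be inspected (or edited) in bm25_documents.stopwords.
print(conn.sql("SELECT * FROM bm25_documents.stopwords LIMIT 5"))

conn.close()
```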
diff --git a/ducksearch/__version__.py b/ducksearch/__version__.py index 4d480bb..30dbba5 100644 --- a/ducksearch/__version__.py +++ b/ducksearch/__version__.py @@ -1,3 +1,3 @@ -VERSION = (1, 0, 1) +VERSION = (1, 0, 2) __version__ = ".".join(map(str, VERSION)) diff --git a/ducksearch/decorators/execute_with_duckdb.py b/ducksearch/decorators/execute_with_duckdb.py index b529799..fcb3ad3 100644 --- a/ducksearch/decorators/execute_with_duckdb.py +++ b/ducksearch/decorators/execute_with_duckdb.py @@ -1,4 +1,5 @@ import pathlib +import time from functools import wraps import duckdb @@ -8,8 +9,11 @@ def connect_to_duckdb( database: str, read_only: bool = False, config: dict | None = None, + max_retry: int = 30, + sleep_time: float = 0.1, + **kwargs, ): - """Establish a connection to the DuckDB database. + """Establish a connection to the DuckDB database. Retry connecting if an error occurs. Parameters ---------- @@ -19,6 +23,10 @@ def connect_to_duckdb( Whether to open the database in read-only mode. Default is False. config Optional configuration settings for the DuckDB connection. + max_retry + The maximum number of times to retry connecting to DuckDB. + sleep_time + The time to sleep between retries. Returns ------- @@ -26,11 +34,22 @@ def connect_to_duckdb( A DuckDB connection object. """ - return ( - duckdb.connect(database=database, read_only=read_only, config=config) - if config - else duckdb.connect(database=database, read_only=read_only) - ) + current_retry = 0 + while True: + try: + conn = ( + duckdb.connect(database=database, read_only=read_only, config=config) + if config + else duckdb.connect(database=database, read_only=read_only) + ) + break + except Exception as error: + if current_retry >= max_retry: + raise error + time.sleep(sleep_time) + current_retry += 1 + + return conn def execute_with_duckdb( @@ -73,7 +92,10 @@ def wrapper( ): """Connect to DuckDB and execute the query from the provided SQL file path(s).""" conn = connect_to_duckdb( - database=database, read_only=read_only, config=config + database=database, + read_only=read_only, + config=config, + **kwargs, ) # Ensure relative_path is treated as a list diff --git a/ducksearch/delete/delete/scores.sql b/ducksearch/delete/delete/scores.sql index a52b430..e60a6af 100644 --- a/ducksearch/delete/delete/scores.sql +++ b/ducksearch/delete/delete/scores.sql @@ -1,15 +1,13 @@ -- This query finds the set of tokens scores for which there won't be any docid / score to keep. 
WITH _docs_to_delete AS ( - SELECT DISTINCT - bm25.docid - FROM parquet_scan('{parquet_file}') p - INNER JOIN bm25_documents.docs bm25 + SELECT DISTINCT bm25.docid + FROM parquet_scan('{parquet_file}') AS p + INNER JOIN bm25_documents.docs AS bm25 ON p.id = bm25.name ), _terms_to_recompute AS ( - SELECT DISTINCT - term + SELECT DISTINCT term FROM bm25_documents.terms INNER JOIN _docs_to_delete ON bm25_documents.terms.docid = _docs_to_delete.docid @@ -22,16 +20,16 @@ _scores_to_update AS ( _bm25.term, _bm25.list_scores, _bm25.list_docids - FROM bm25_documents.scores _bm25 - INNER JOIN _terms_to_recompute _terms + FROM bm25_documents.scores AS _bm25 + INNER JOIN _terms_to_recompute AS _terms ON _bm25.term = _terms.term ), _unested_scores AS ( SELECT term, - UNNEST(list_scores) AS score, - UNNEST(list_docids) AS docid + unnest(list_scores) AS score, + unnest(list_docids) AS docid FROM _scores_to_update ), @@ -41,8 +39,8 @@ _unested_unfiltered_scores AS ( _scores.docid, _scores.score, _docs.docid AS to_delete - FROM _unested_scores _scores - LEFT JOIN _docs_to_delete _docs + FROM _unested_scores AS _scores + LEFT JOIN _docs_to_delete AS _docs ON _scores.docid = _docs.docid ), @@ -59,18 +57,17 @@ _terms_to_delete AS ( SELECT DISTINCT ttr.term, ufs.term AS missing - FROM _terms_to_recompute ttr - LEFT JOIN _unested_filtered_scores ufs + FROM _terms_to_recompute AS ttr + LEFT JOIN _unested_filtered_scores AS ufs ON ttr.term = ufs.term ), _scores_to_delete_completely AS ( - SELECT DISTINCT - term, + SELECT DISTINCT term FROM _terms_to_delete WHERE missing IS NULL ) -DELETE FROM bm25_documents.scores as _scores -USING _scores_to_delete_completely as _scores_to_delete -WHERE _scores.term = _scores_to_delete.term; \ No newline at end of file +DELETE FROM bm25_documents.scores AS _scores +USING _scores_to_delete_completely AS _scores_to_delete +WHERE _scores.term = _scores_to_delete.term; diff --git a/ducksearch/delete/update/df.sql b/ducksearch/delete/update/df.sql index 60926d4..8a4137d 100644 --- a/ducksearch/delete/update/df.sql +++ b/ducksearch/delete/update/df.sql @@ -1,22 +1,21 @@ WITH _docs_to_delete AS ( - SELECT DISTINCT - bm25.docid - FROM parquet_scan('{parquet_file}') p - INNER JOIN bm25_documents.docs bm25 + SELECT DISTINCT bm25.docid + FROM parquet_scan('{parquet_file}') AS p + INNER JOIN bm25_documents.docs AS bm25 ON p.id = bm25.name ), _tf AS ( SELECT termid, - sum(tf) as df + sum(tf) AS df FROM bm25_documents.terms INNER JOIN _docs_to_delete - ON bm25_documents.terms.docid = _docs_to_delete.docid + ON bm25_documents.terms.docid = _docs_to_delete.docid GROUP BY 1 ) UPDATE bm25_documents.dict _dict -SET df = GREATEST(_dict.df - _tf.df, 0) +SET df = greatest(_dict.df - _tf.df, 0) FROM _tf -WHERE _dict.termid = _tf.termid; \ No newline at end of file +WHERE _dict.termid = _tf.termid; diff --git a/ducksearch/delete/update/docs.sql b/ducksearch/delete/update/docs.sql index 43afe39..2fffdb5 100644 --- a/ducksearch/delete/update/docs.sql +++ b/ducksearch/delete/update/docs.sql @@ -1,3 +1,3 @@ -DELETE FROM bm25_documents.docs as _docs +DELETE FROM bm25_documents.docs AS _docs USING parquet_scan('{parquet_file}') AS _df_documents -WHERE _docs.name = _df_documents.id; \ No newline at end of file +WHERE _docs.name = _df_documents.id; diff --git a/ducksearch/delete/update/scores.sql b/ducksearch/delete/update/scores.sql index aef82f4..1e67658 100644 --- a/ducksearch/delete/update/scores.sql +++ b/ducksearch/delete/update/scores.sql @@ -1,15 +1,13 @@ -- This query finds the set of tokens scores 
for which there won't be any docid / score to keep. WITH _docs_to_delete AS ( - SELECT DISTINCT - bm25.docid - FROM parquet_scan('{parquet_file}') p - INNER JOIN bm25_documents.docs bm25 + SELECT DISTINCT bm25.docid + FROM parquet_scan('{parquet_file}') AS p + INNER JOIN bm25_documents.docs AS bm25 ON p.id = bm25.name ), _terms_to_recompute AS ( - SELECT DISTINCT - term + SELECT DISTINCT term FROM bm25_documents.terms INNER JOIN _docs_to_delete ON bm25_documents.terms.docid = _docs_to_delete.docid @@ -22,16 +20,16 @@ _scores_to_update AS ( _bm25.term, _bm25.list_scores, _bm25.list_docids - FROM bm25_documents.scores _bm25 - INNER JOIN _terms_to_recompute _terms + FROM bm25_documents.scores AS _bm25 + INNER JOIN _terms_to_recompute AS _terms ON _bm25.term = _terms.term ), _unested_scores AS ( SELECT term, - UNNEST(list_scores) AS score, - UNNEST(list_docids) AS docid + unnest(list_scores) AS score, + unnest(list_docids) AS docid FROM _scores_to_update ), @@ -41,8 +39,8 @@ _unested_unfiltered_scores AS ( _scores.docid, _scores.score, _docs.docid AS to_delete - FROM _unested_scores _scores - LEFT JOIN _docs_to_delete _docs + FROM _unested_scores AS _scores + LEFT JOIN _docs_to_delete AS _docs ON _scores.docid = _docs.docid ), @@ -58,8 +56,8 @@ _unested_filtered_scores AS ( _list_scores AS ( SELECT term, - LIST(docid ORDER BY score DESC, docid ASC) AS list_docids, - LIST(score ORDER BY score DESC, docid ASC) AS list_scores + list(docid ORDER BY score DESC, docid ASC) AS list_docids, + list(score ORDER BY score DESC, docid ASC) AS list_scores FROM _unested_filtered_scores GROUP BY 1 ) @@ -68,5 +66,5 @@ UPDATE bm25_documents.scores s SET list_docids = u.list_docids, list_scores = u.list_scores -FROM _list_scores u -WHERE s.term = u.term; \ No newline at end of file +FROM _list_scores AS u +WHERE s.term = u.term; diff --git a/ducksearch/delete/update/stats.sql b/ducksearch/delete/update/stats.sql index 364f6c9..9d4054c 100644 --- a/ducksearch/delete/update/stats.sql +++ b/ducksearch/delete/update/stats.sql @@ -2,10 +2,11 @@ WITH _stats AS ( SELECT COUNT(*) AS num_docs, AVG(len) AS avgdl - FROM bm25_documents.docs + FROM bm25_documents.docs ) -UPDATE bm25_documents.stats -SET num_docs = _stats.num_docs, +UPDATE bm25_documents.stats +SET + num_docs = _stats.num_docs, avgdl = _stats.avgdl -FROM _stats; \ No newline at end of file +FROM _stats; diff --git a/ducksearch/delete/update/terms.sql b/ducksearch/delete/update/terms.sql index 8907a6c..be48045 100644 --- a/ducksearch/delete/update/terms.sql +++ b/ducksearch/delete/update/terms.sql @@ -1,11 +1,10 @@ WITH _docs_to_delete AS ( - SELECT - bm25.docid - FROM parquet_scan('{parquet_file}') p - INNER JOIN bm25_documents.docs bm25 + SELECT bm25.docid + FROM parquet_scan('{parquet_file}') AS p + INNER JOIN bm25_documents.docs AS bm25 ON p.id = bm25.name ) -DELETE FROM bm25_documents.terms as _terms -USING _docs_to_delete as _docs -WHERE _terms.docid = _docs.docid; \ No newline at end of file +DELETE FROM bm25_documents.terms AS _terms +USING _docs_to_delete AS _docs +WHERE _terms.docid = _docs.docid; diff --git a/ducksearch/hf/drop/tmp.sql b/ducksearch/hf/drop/tmp.sql new file mode 100644 index 0000000..4acce8c --- /dev/null +++ b/ducksearch/hf/drop/tmp.sql @@ -0,0 +1 @@ +DROP TABLE {schema}._hf_tmp; diff --git a/ducksearch/hf/insert.py b/ducksearch/hf/insert.py index fd81c7e..10d4eb8 100644 --- a/ducksearch/hf/insert.py +++ b/ducksearch/hf/insert.py @@ -1,4 +1,5 @@ from ..decorators import execute_with_duckdb +from ..tables import 
add_columns_documents, create_documents @execute_with_duckdb( @@ -9,14 +10,48 @@ def _insert_documents() -> None: """Insert the documents from Hugging Face datasets into DuckDB.""" +@execute_with_duckdb( + relative_path="hf/select/columns.sql", + fetch_df=True, + read_only=True, +) +def _select_columns() -> None: + """Select all columns from the HuggingFace documents table.""" + + +@execute_with_duckdb( + relative_path="hf/select/exists.sql", + fetch_df=True, + read_only=True, +) +def _table_exists() -> None: + """Check if the table exists in the DuckDB database.""" + + +@execute_with_duckdb( + relative_path="hf/insert/tmp.sql", + fetch_df=False, +) +def _insert_tmp_documents() -> None: + """Insert the documents from Hugging Face datasets into DuckDB.""" + + +@execute_with_duckdb( + relative_path="hf/drop/tmp.sql", + fetch_df=True, +) +def _drop_tmp_table() -> None: + """Drop the temporary HF table.""" + + def insert_documents( database: str, schema: str, key: str, - fields: str | list[str], url: str, config: dict | None = None, limit: int | None = None, + dtypes: dict | None = None, ) -> None: """Insert documents from a Hugging Face dataset into DuckDB. @@ -41,32 +76,96 @@ def insert_documents( >>> upload.documents( ... database="test.duckdb", - ... documents="hf://datasets/lightonai/lighton-ms-marco-mini/train.parquet", - ... fields=["document_ids", "scores"], + ... documents="hf://datasets/lightonai/lighton-ms-marco-mini/queries.parquet", ... key="query_id", + ... fields=["query_id", "text"], ... ) | Table | Size | |----------------|------| | documents | 19 | | bm25_documents | 19 | + >>> upload.documents( + ... database="test.duckdb", + ... documents="hf://datasets/lightonai/lighton-ms-marco-mini/documents.parquet", + ... key="document_id", + ... fields=["document_id", "text"], + ... 
) + | Table | Size | + |----------------|------| + | documents | 51 | + | bm25_documents | 51 | + """ - if isinstance(fields, str): - fields = [fields] + limit_hf = f"LIMIT {limit}" if limit is not None else "" + + _insert_tmp_documents( + database=database, + schema=schema, + url=url, + key_field=key, + config=config, + limit_hf=limit_hf, + ) - fields = [field for field in fields if field != "id"] + exists = _table_exists( + database=database, + schema=schema, + table_name="documents", + )[0]["table_exists"] - limit_hf = f"LIMIT {limit}" if limit is not None else "" + _hf_tmp_columns = _select_columns( + database=database, + schema=schema, + table_name="_hf_tmp", + ) + + _hf_tmp_columns = [ + column["column"] for column in _hf_tmp_columns if column["column"] != "id" + ] + + if exists: + documents_columns = _select_columns( + database=database, + schema=schema, + table_name="documents", + ) + + documents_columns = set( + [column["column"] for column in documents_columns if column != "id"] + ) - if not fields: - fields.append(key) + columns_to_add = list(set(_hf_tmp_columns) - documents_columns) - return _insert_documents( + if columns_to_add: + add_columns_documents( + database=database, + schema=schema, + columns=columns_to_add, + dtypes=dtypes, + config=config, + ) + else: + create_documents( + database=database, + schema=schema, + columns=_hf_tmp_columns, + dtypes=dtypes, + config=config, + ) + + _insert_documents( database=database, schema=schema, url=url, key_field=key, - fields=", ".join(fields), + _hf_tmp_columns=", ".join(_hf_tmp_columns), limit_hf=limit_hf, config=config, ) + + _drop_tmp_table( + database=database, + schema=schema, + config=config, + ) diff --git a/ducksearch/hf/insert/documents.sql b/ducksearch/hf/insert/documents.sql index 731c754..411b29e 100644 --- a/ducksearch/hf/insert/documents.sql +++ b/ducksearch/hf/insert/documents.sql @@ -1,32 +1,22 @@ -INSERT INTO {schema}.documents (id, {fields}) ( +INSERT INTO {schema}.documents (id, {_hf_tmp_columns}) ( WITH _hf_dataset AS ( SELECT - {key_field} AS id, - {fields} - FROM '{url}' - {limit_hf} - ), - - _hf_row_number AS ( - SELECT - *, - ROW_NUMBER() OVER (PARTITION BY id ORDER BY id, RANDOM()) AS _row_number - FROM _hf_dataset + id, + * EXCLUDE (id) + FROM {schema}._hf_tmp ), _new_hf_dataset AS ( SELECT - _hf_row_number.*, + _hf_dataset.*, d.id AS existing_id - FROM _hf_row_number + FROM _hf_dataset LEFT JOIN {schema}.documents AS d - ON _hf_row_number.id = d.id - WHERE _row_number = 1 + ON _hf_dataset.id = d.id ) - SELECT id, {fields} + SELECT id, {_hf_tmp_columns} FROM _new_hf_dataset - WHERE _row_number = 1 - AND existing_id IS NULL + WHERE existing_id IS NULL ); diff --git a/ducksearch/hf/insert/tmp.sql b/ducksearch/hf/insert/tmp.sql new file mode 100644 index 0000000..6597580 --- /dev/null +++ b/ducksearch/hf/insert/tmp.sql @@ -0,0 +1,20 @@ +CREATE OR REPLACE TABLE {schema}._hf_tmp AS ( + WITH _hf_dataset AS ( + SELECT + {key_field} AS id, + * + FROM '{url}' + {limit_hf} + ), + + _hf_row_number AS ( + SELECT + *, + ROW_NUMBER() OVER (PARTITION BY id ORDER BY id, RANDOM()) AS _row_number + FROM _hf_dataset + ) + + SELECT * EXCLUDE (_row_number) + FROM _hf_row_number + WHERE _row_number = 1 +); diff --git a/ducksearch/hf/select/columns.sql b/ducksearch/hf/select/columns.sql new file mode 100644 index 0000000..f12fa22 --- /dev/null +++ b/ducksearch/hf/select/columns.sql @@ -0,0 +1,5 @@ +SELECT column_name as column +FROM information_schema.columns +WHERE + lower(table_name) = '{table_name}' + AND table_schema = 
'{schema}'; diff --git a/ducksearch/hf/select/exists.sql b/ducksearch/hf/select/exists.sql new file mode 100644 index 0000000..a1167d4 --- /dev/null +++ b/ducksearch/hf/select/exists.sql @@ -0,0 +1,7 @@ +SELECT EXISTS( + SELECT 1 + FROM information_schema.tables + WHERE + LOWER(table_name) = LOWER('{table_name}') + AND table_schema = '{schema}' +) AS table_exists; diff --git a/ducksearch/search/create.py b/ducksearch/search/create.py index a302a65..e43738f 100644 --- a/ducksearch/search/create.py +++ b/ducksearch/search/create.py @@ -5,7 +5,6 @@ import pyarrow.parquet as pq from ..decorators import execute_with_duckdb -from ..tables import select_documents_columns from ..utils import batchify @@ -398,6 +397,7 @@ def update_index( def update_index_documents( database: str, + fields: list[str], k1: float = 1.5, b: float = 0.75, stemmer: str = "porter", @@ -414,6 +414,8 @@ def update_index_documents( ---------- database The name of the DuckDB database. + fields + The fields to index for each document. k1 The BM25 k1 parameter, controls term saturation. b @@ -450,14 +452,6 @@ def update_index_documents( | bm25_documents | 5183 | """ - fields = ", ".join( - select_documents_columns( - database=database, - schema="bm25_tables", - config=config, - ) - ) - update_index( database=database, k1=k1, diff --git a/ducksearch/search/create/queries_index.sql b/ducksearch/search/create/queries_index.sql index 0dcaa87..00a7703 100644 --- a/ducksearch/search/create/queries_index.sql +++ b/ducksearch/search/create/queries_index.sql @@ -1,5 +1,5 @@ PRAGMA CREATE_FTS_INDEX( - '{schema}._queries', + '{schema}._queries_{random_hash}', 'query', 'query', STEMMER='{stemmer}', diff --git a/ducksearch/search/drop/queries.sql b/ducksearch/search/drop/queries.sql new file mode 100644 index 0000000..c69b8c6 --- /dev/null +++ b/ducksearch/search/drop/queries.sql @@ -0,0 +1,2 @@ +DROP SCHEMA fts_{schema}__queries_{random_hash} CASCADE; +DROP TABLE {schema}._queries_{random_hash}; diff --git a/ducksearch/search/graphs.py b/ducksearch/search/graphs.py index 1e34895..934c64a 100644 --- a/ducksearch/search/graphs.py +++ b/ducksearch/search/graphs.py @@ -5,10 +5,11 @@ import pyarrow as pa import pyarrow.parquet as pq -from joblib import Parallel, delayed +import tqdm +from joblib import delayed from ..decorators import execute_with_duckdb -from ..utils import batchify +from ..utils import ParallelTqdm, batchify, generate_random_hash from .create import _select_settings from .select import _create_queries_index, _insert_queries @@ -36,7 +37,8 @@ def _search_graph( queries: list[str], top_k: int, top_k_token: int, - index: int, + group_id: int, + random_hash: str, config: dict | None = None, filters: str | None = None, ) -> list: @@ -52,7 +54,7 @@ def _search_graph( The number of top results to retrieve for each query. top_k_token The number of top tokens to retrieve. Used to select top documents per token. - index + group_id The index of the current batch of queries. config Optional configuration settings for the DuckDB connection. 
@@ -68,24 +70,19 @@ def _search_graph( _search_graph_filters_query if filters is not None else _search_graph_query ) - index_table = pa.Table.from_pydict({"query": queries}) - pq.write_table(index_table, f"_queries_{index}.parquet", compression="snappy") - matchs = search_function( database=database, queries_schema="bm25_queries", documents_schema="bm25_documents", source_schema="bm25_tables", top_k=top_k, + group_id=group_id, + random_hash=random_hash, top_k_token=top_k_token, - parquet_file=f"_queries_{index}.parquet", filters=filters, config=config, ) - if os.path.exists(f"_queries_{index}.parquet"): - os.remove(f"_queries_{index}.parquet") - candidates = collections.defaultdict(list) for match in matchs: query = match.pop("_query") @@ -102,6 +99,7 @@ def graphs( n_jobs: int = -1, config: dict | None = None, filters: str | None = None, + tqdm_bar: bool = True, ) -> list[dict]: """Search for graphs in DuckDB using the provided queries. @@ -159,11 +157,7 @@ def graphs( | bm25_queries | 807 | | documents_queries | 916 | - >>> scores = search.graphs( - ... database="test.duckdb", - ... queries=queries, - ... top_k=10, - ... ) + """ resource.setrlimit( @@ -174,18 +168,38 @@ def graphs( queries = [queries] logging.info("Indexing queries.") - index_table = pa.Table.from_pydict({"query": queries}) - pq.write_table(index_table, "_queries.parquet", compression="snappy") + random_hash = generate_random_hash() + + batchs = { + group_id: batch + for group_id, batch in enumerate( + iterable=batchify( + X=queries, batch_size=batch_size, desc="Searching", tqdm_bar=False + ) + ) + } + + parquet_file = f"_queries_{random_hash}.parquet" + pa_queries, pa_group_ids = [], [] + for group_id, batch_queries in batchs.items(): + pa_queries.extend(batch_queries) + pa_group_ids.extend([group_id] * len(batch_queries)) + + logging.info("Indexing queries.") + index_table = pa.Table.from_pydict({"query": pa_queries, "group_id": pa_group_ids}) + + pq.write_table(index_table, parquet_file, compression="snappy") _insert_queries( database=database, schema="bm25_documents", - parquet_file="_queries.parquet", + parquet_file=parquet_file, + random_hash=random_hash, config=config, ) - if os.path.exists("_queries.parquet"): - os.remove("_queries.parquet") + if os.path.exists(parquet_file): + os.remove(parquet_file) settings = _select_settings( database=database, schema="bm25_documents", config=config @@ -194,27 +208,55 @@ def graphs( _create_queries_index( database=database, schema="bm25_documents", + random_hash=random_hash, **settings, config=config, ) matchs = [] - for match in Parallel( - n_jobs=1 if len(queries) <= batch_size else n_jobs, backend="threading" - )( - delayed(_search_graph)( - database, - batch_queries, - top_k, - top_k_token, - index, - config, - filters, - ) - for index, batch_queries in enumerate( - batchify(queries, batch_size=batch_size, desc="Searching") - ) - ): - matchs.extend(match) + if n_jobs == 1 or len(batchs) == 1: + if tqdm_bar: + bar = tqdm.tqdm( + total=len(batchs), + position=0, + desc="Searching", + ) + + for group_id, batch_queries in batchs.items(): + matchs.extend( + _search_graph( + database=database, + queries=batch_queries, + top_k=top_k, + top_k_token=top_k_token, + group_id=group_id, + random_hash=random_hash, + config=config, + filters=filters, + ) + ) + if tqdm_bar: + bar.update(1) + else: + for match in ParallelTqdm( + n_jobs=n_jobs, + backend="threading", + total=len(batchs), + desc="Searching", + tqdm_bar=tqdm_bar, + )( + delayed(_search_graph)( + database, + batch_queries, + 
top_k, + top_k_token, + group_id, + random_hash, + config, + filters, + ) + for group_id, batch_queries in batchs.items() + ): + matchs.extend(match) return matchs diff --git a/ducksearch/search/insert/queries.sql b/ducksearch/search/insert/queries.sql index 180ea71..5101e03 100644 --- a/ducksearch/search/insert/queries.sql +++ b/ducksearch/search/insert/queries.sql @@ -1,5 +1,6 @@ -CREATE OR REPLACE TABLE {schema}._queries AS ( +CREATE OR REPLACE TABLE {schema}._queries_{random_hash} AS ( SELECT - query + query, + group_id FROM parquet_scan('{parquet_file}') ); diff --git a/ducksearch/search/select.py b/ducksearch/search/select.py index a0ce629..885dbd3 100644 --- a/ducksearch/search/select.py +++ b/ducksearch/search/select.py @@ -4,10 +4,11 @@ import pyarrow as pa import pyarrow.parquet as pq -from joblib import Parallel, delayed +import tqdm +from joblib import delayed from ..decorators import execute_with_duckdb -from ..utils import batchify +from ..utils import ParallelTqdm, batchify, generate_random_hash from .create import _select_settings @@ -18,6 +19,13 @@ def _create_queries_index() -> None: """Create an index for the queries table in the DuckDB database.""" +@execute_with_duckdb( + relative_path="search/drop/queries.sql", +) +def _delete_queries_index() -> None: + """Delete the queries index from the DuckDB database.""" + + @execute_with_duckdb( relative_path="search/insert/queries.sql", ) @@ -34,6 +42,15 @@ def _search_query(): """Perform a search on the documents or queries table in DuckDB.""" +@execute_with_duckdb( + relative_path="search/select/search_order_by.sql", + read_only=True, + fetch_df=True, +) +def _search_query_order_by(): + """Perform a search on the documents or queries table in DuckDB.""" + + @execute_with_duckdb( relative_path="search/select/search_filters.sql", read_only=True, @@ -52,6 +69,8 @@ def documents( n_jobs: int = -1, config: dict | None = None, filters: str | None = None, + order_by: str | None = None, + tqdm_bar: bool = True, ) -> list[list[dict]]: """Search for documents in the documents table using specified queries. @@ -73,6 +92,8 @@ def documents( Optional configuration for DuckDB connection settings. filters Optional SQL filters to apply during the search. + tqdm_bar + Whether to display a progress bar when searching. Returns ------- @@ -107,6 +128,8 @@ def documents( top_k_token=top_k_token, n_jobs=n_jobs, filters=filters, + order_by=order_by, + tqdm_bar=tqdm_bar, ) @@ -119,6 +142,7 @@ def queries( n_jobs: int = -1, config: dict | None = None, filters: str | None = None, + tqdm_bar: bool = True, ) -> list[list[dict]]: """Search for queries in the queries table using specified queries. @@ -169,6 +193,7 @@ def queries( top_k_token=top_k_token, n_jobs=n_jobs, filters=filters, + tqdm_bar=tqdm_bar, ) @@ -180,9 +205,11 @@ def _search( queries: list[str], top_k: int, top_k_token: int, - index: int, + group_id: int, + random_hash: str, config: dict | None = None, filters: str | None = None, + order_by: str | None = None, ) -> list: """Perform a search on the specified source table (documents or queries). 
@@ -216,8 +243,10 @@ def _search( """ search_function = _search_query_filters if filters is not None else _search_query - index_table = pa.Table.from_pydict({"query": queries}) - pq.write_table(index_table, f"_queries_{index}.parquet", compression="snappy") + if filters is None and order_by is not None: + search_function = _search_query_order_by + + order_by = f"ORDER BY {order_by}" if order_by is not None else "ORDER BY score DESC" matchs = search_function( database=database, @@ -226,21 +255,19 @@ def _search( source=source, top_k=top_k, top_k_token=top_k_token, - parquet_file=f"_queries_{index}.parquet", + random_hash=random_hash, + group_id=group_id, filters=filters, config=config, + order_by=order_by, ) - if os.path.exists(f"_queries_{index}.parquet"): - os.remove(f"_queries_{index}.parquet") - candidates = collections.defaultdict(list) for match in matchs: query = match.pop("_query") candidates[query].append(match) candidates = [candidates[query] for query in queries] - return candidates @@ -256,6 +283,8 @@ def search( n_jobs: int = -1, config: dict | None = None, filters: str | None = None, + order_by: str | None = None, + tqdm_bar: bool = True, ) -> list[list[dict]]: """Run the search for documents or queries in parallel. @@ -283,6 +312,8 @@ def search( Optional configuration for DuckDB connection settings. filters Optional SQL filters to apply during the search. + tqdm_bar + Whether to display a progress bar when searching. Returns ------- @@ -311,57 +342,116 @@ def search( queries = [queries] is_query_str = True - logging.info("Indexing queries.") - index_table = pa.Table.from_pydict({"query": queries}) - settings = _select_settings( database=database, schema=schema, config=config, )[0] + batchs = { + group_id: batch + for group_id, batch in enumerate( + iterable=batchify( + X=queries, batch_size=batch_size, desc="Searching", tqdm_bar=False + ) + ) + } + + pa_queries, pa_group_ids = [], [] + for group_id, batch_queries in batchs.items(): + pa_queries.extend(batch_queries) + pa_group_ids.extend([group_id] * len(batch_queries)) + + logging.info("Indexing queries.") + index_table = pa.Table.from_pydict({"query": pa_queries, "group_id": pa_group_ids}) + + random_hash = generate_random_hash() + parquet_file = f"_queries_{random_hash}.parquet" + pq.write_table( index_table, - "_queries.parquet", + parquet_file, compression="snappy", ) _insert_queries( database=database, schema=schema, - parquet_file="_queries.parquet", + parquet_file=parquet_file, + random_hash=random_hash, config=config, ) - if os.path.exists("_queries.parquet"): - os.remove("_queries.parquet") + if os.path.exists(path=parquet_file): + os.remove(path=parquet_file) _create_queries_index( database=database, schema=schema, + random_hash=random_hash, **settings, config=config, ) matchs = [] - - for match in Parallel(n_jobs=n_jobs, backend="threading")( - delayed(_search)( - database, - schema, - source_schema, - source, - batch_queries, - top_k, - top_k_token, - index, - config, - filters=filters, - ) - for index, batch_queries in enumerate( - batchify(queries, batch_size=batch_size, desc="Searching") - ) - ): - matchs.extend(match) + if n_jobs == 1 or len(batchs) == 1: + if tqdm_bar: + bar = tqdm.tqdm( + total=len(batchs), + position=0, + desc="Searching", + ) + + for group_id, batch_queries in batchs.items(): + matchs.extend( + _search( + database=database, + schema=schema, + source_schema=source_schema, + source=source, + queries=batch_queries, + top_k=top_k, + top_k_token=top_k_token, + group_id=group_id, + 
random_hash=random_hash, + config=config, + filters=filters, + order_by=order_by, + ) + ) + if tqdm_bar: + bar.update(1) + else: + for match in ParallelTqdm( + n_jobs=n_jobs, + backend="threading", + total=len(batchs), + desc="Searching", + tqdm_bar=tqdm_bar, + )( + delayed(_search)( + database, + schema, + source_schema, + source, + batch_queries, + top_k, + top_k_token, + group_id, + random_hash, + config, + filters, + order_by, + ) + for group_id, batch_queries in batchs.items() + ): + matchs.extend(match) + + _delete_queries_index( + database=database, + schema=schema, + random_hash=random_hash, + config=config, + ) return matchs[0] if is_query_str else matchs diff --git a/ducksearch/search/select/search.sql b/ducksearch/search/select/search.sql index b7f3932..c8d21e4 100644 --- a/ducksearch/search/select/search.sql +++ b/ducksearch/search/select/search.sql @@ -1,13 +1,20 @@ -WITH _input_queries AS ( +WITH group_queries AS ( + SELECT + query + FROM {schema}._queries_{random_hash} + WHERE group_id = {group_id} +), + + _input_queries AS ( SELECT pf.query, ftsdict.term - FROM parquet_scan('{parquet_file}') pf - JOIN fts_{schema}__queries.docs docs + FROM group_queries pf + JOIN fts_{schema}__queries_{random_hash}.docs docs ON pf.query = docs.name - JOIN fts_{schema}__queries.terms terms + JOIN fts_{schema}__queries_{random_hash}.terms terms ON docs.docid = terms.docid - JOIN fts_{schema}__queries.dict ftsdict + JOIN fts_{schema}__queries_{random_hash}.dict ftsdict ON terms.termid = ftsdict.termid ), diff --git a/ducksearch/search/select/search_filters.sql b/ducksearch/search/select/search_filters.sql index 51bdbcb..6f6f60d 100644 --- a/ducksearch/search/select/search_filters.sql +++ b/ducksearch/search/select/search_filters.sql @@ -1,13 +1,20 @@ -WITH _input_queries AS ( +WITH group_queries AS ( + SELECT + query + FROM {schema}._queries_{random_hash} + WHERE group_id = {group_id} +), + + _input_queries AS ( SELECT pf.query, ftsdict.term - FROM parquet_scan('{parquet_file}') pf - JOIN fts_{schema}__queries.docs docs + FROM group_queries pf + JOIN fts_{schema}__queries_{random_hash}.docs docs ON pf.query = docs.name - JOIN fts_{schema}__queries.terms terms + JOIN fts_{schema}__queries_{random_hash}.terms terms ON docs.docid = terms.docid - JOIN fts_{schema}__queries.dict ftsdict + JOIN fts_{schema}__queries_{random_hash}.dict ftsdict ON terms.termid = ftsdict.termid ), @@ -56,11 +63,12 @@ _partition_scores AS ( _query, _score AS score, * EXCLUDE (_score, _query), - RANK() OVER (PARTITION BY _query ORDER BY _score DESC) AS _row_number + RANK() OVER (PARTITION BY _query {order_by}, RANDOM() ASC) AS _row_number FROM _filtered_scores QUALIFY _row_number <= {top_k} ) SELECT * EXCLUDE (_row_number) -FROM _partition_scores; +FROM _partition_scores +{order_by}; diff --git a/ducksearch/search/select/search_graph.sql b/ducksearch/search/select/search_graph.sql index 24dba08..1d5db51 100644 --- a/ducksearch/search/select/search_graph.sql +++ b/ducksearch/search/select/search_graph.sql @@ -1,13 +1,20 @@ -WITH _input_queries AS ( +WITH group_queries AS ( + SELECT + query + FROM {documents_schema}._queries_{random_hash} + WHERE group_id = {group_id} +), + + _input_queries AS ( SELECT pf.query, ftsdict.term - FROM parquet_scan('{parquet_file}') pf - JOIN fts_{documents_schema}__queries.docs docs + FROM group_queries pf + JOIN fts_{documents_schema}__queries_{random_hash}.docs docs ON pf.query = docs.name - JOIN fts_{documents_schema}__queries.terms terms + JOIN 
fts_{documents_schema}__queries_{random_hash}.terms terms ON docs.docid = terms.docid - JOIN fts_{documents_schema}__queries.dict ftsdict + JOIN fts_{documents_schema}__queries_{random_hash}.dict ftsdict ON terms.termid = ftsdict.termid ), diff --git a/ducksearch/search/select/search_graph_filters.sql b/ducksearch/search/select/search_graph_filters.sql index 8b8f946..d504f34 100644 --- a/ducksearch/search/select/search_graph_filters.sql +++ b/ducksearch/search/select/search_graph_filters.sql @@ -1,13 +1,20 @@ -WITH _input_queries AS ( +WITH group_queries AS ( + SELECT + query + FROM {documents_schema}._queries_{random_hash} + WHERE group_id = {group_id} +), + + _input_queries AS ( SELECT pf.query, ftsdict.term - FROM parquet_scan('{parquet_file}') pf - JOIN fts_{documents_schema}__queries.docs docs + FROM group_queries pf + JOIN fts_{documents_schema}__queries_{random_hash}.docs docs ON pf.query = docs.name - JOIN fts_{documents_schema}__queries.terms terms + JOIN fts_{documents_schema}__queries_{random_hash}.terms terms ON docs.docid = terms.docid - JOIN fts_{documents_schema}__queries.dict ftsdict + JOIN fts_{documents_schema}__queries_{random_hash}.dict ftsdict ON terms.termid = ftsdict.termid ), diff --git a/ducksearch/search/select/search_order_by.sql b/ducksearch/search/select/search_order_by.sql new file mode 100644 index 0000000..4ae5262 --- /dev/null +++ b/ducksearch/search/select/search_order_by.sql @@ -0,0 +1,74 @@ +WITH group_queries AS ( + SELECT + query + FROM {schema}._queries_{random_hash} + WHERE group_id = {group_id} +), + + _input_queries AS ( + SELECT + pf.query, + ftsdict.term + FROM group_queries pf + JOIN fts_{schema}__queries_{random_hash}.docs docs + ON pf.query = docs.name + JOIN fts_{schema}__queries_{random_hash}.terms terms + ON docs.docid = terms.docid + JOIN fts_{schema}__queries_{random_hash}.dict ftsdict + ON terms.termid = ftsdict.termid +), + +_nested_matchs AS ( + SELECT + iq.query, + s.list_docids[0:{top_k_token}] as list_docids, + s.list_scores[0:{top_k_token}] as list_scores + FROM {schema}.scores s + INNER JOIN _input_queries iq + ON s.term = iq.term +), + +_matchs AS ( + SELECT + query, + UNNEST( + s.list_docids + ) AS bm25id, + UNNEST( + s.list_scores + ) AS score + FROM _nested_matchs s +), + +_matchs_scores AS ( + SELECT + query, + bm25id, + SUM(score) AS score + FROM _matchs + GROUP BY 1, 2 +), + +_match_scores_documents AS ( + SELECT + ms.query AS _query, + ms.bm25id, + ms.score, + s.* + FROM _matchs_scores ms + INNER JOIN {source_schema}.{source} s + ON ms.bm25id = s.bm25id +), + +_partition_scores AS ( + SELECT + *, + RANK() OVER (PARTITION BY _query {order_by}, RANDOM() ASC) AS rank + FROM _match_scores_documents + QUALIFY rank <= {top_k} +) + +SELECT + * +FROM _partition_scores +{order_by}; diff --git a/ducksearch/search/select/settings_exists.sql b/ducksearch/search/select/settings_exists.sql index 8308663..3584a02 100644 --- a/ducksearch/search/select/settings_exists.sql +++ b/ducksearch/search/select/settings_exists.sql @@ -1,10 +1,7 @@ -SELECT CASE - WHEN EXISTS ( - SELECT 1 - FROM information_schema.tables - WHERE table_name = 'settings' +SELECT coalesce(EXISTS ( + SELECT 1 + FROM information_schema.tables + WHERE + table_name = 'settings' AND table_schema = '{schema}' - ) - THEN TRUE - ELSE FALSE -END AS table_exists; +), FALSE) AS table_exists; diff --git a/ducksearch/tables/__init__.py b/ducksearch/tables/__init__.py index 86c6303..289edb7 100644 --- a/ducksearch/tables/__init__.py +++ b/ducksearch/tables/__init__.py @@ -14,6 +14,7 
@@ select_documents_columns, select_queries, ) +from .update import add_columns_documents __all__ = [ "create_documents", @@ -26,4 +27,5 @@ "select_documents", "select_documents_columns", "select_queries", + "add_columns_documents", ] diff --git a/ducksearch/tables/create.py b/ducksearch/tables/create.py index e9c34e8..e8c5fee 100644 --- a/ducksearch/tables/create.py +++ b/ducksearch/tables/create.py @@ -64,7 +64,7 @@ def create_schema( def create_documents( database: str, schema: str, - fields: str | list[str], + columns: str | list[str], dtypes: dict[str, str] | None = None, config: dict | None = None, ) -> None: @@ -76,8 +76,8 @@ def create_documents( The name of the DuckDB database. schema: str The schema in which to create the documents table. - fields: str or list[str] - The list of fields for the documents table. If a string is provided, it will be converted into a list. + columns: str or list[str] + The list of columns for the documents table. If a string is provided, it will be converted into a list. dtypes: dict[str, str], optional A dictionary specifying field names as keys and their DuckDB types as values. Defaults to 'VARCHAR' if not provided. config: dict, optional @@ -95,7 +95,7 @@ def create_documents( >>> tables.create_documents( ... database="test.duckdb", ... schema="bm25_tables", - ... fields=["title", "text"], + ... columns=["title", "text"], ... dtypes={"text": "VARCHAR", "title": "VARCHAR"}, ... ) @@ -110,18 +110,19 @@ def create_documents( ... schema="bm25_tables", ... key="id", ... df=df, - ... fields=["title", "text"], + ... columns=["title", "text"], ... ) """ - if isinstance(fields, str): - fields = [fields] - if not dtypes: dtypes = {} - fields = ", ".join([f"{field} {dtypes.get(field, 'VARCHAR')}" for field in fields]) return _create_documents( - database=database, schema=schema, fields=fields, config=config + database=database, + schema=schema, + fields=", ".join( + [f"{field} {dtypes.get(field, 'VARCHAR')}" for field in columns] + ), + config=config, ) diff --git a/ducksearch/tables/insert.py b/ducksearch/tables/insert.py index f86e0e2..76fa30f 100644 --- a/ducksearch/tables/insert.py +++ b/ducksearch/tables/insert.py @@ -2,13 +2,11 @@ import os import shutil -import pandas as pd import pyarrow as pa import pyarrow.parquet as pq from joblib import Parallel, delayed from ..decorators import execute_with_duckdb -from ..hf import insert_documents as hf_insert_documents from ..utils import batchify from .create import ( create_documents, @@ -57,11 +55,16 @@ def write_parquet( """ documents_table = collections.defaultdict(list) + fields = set() for document in documents: - if key is not None: - documents_table[key].append(document[key]) + for field in document.keys(): + if field != "id": + fields.add(field) + + for document in documents: + documents_table["id"].append(document[key]) for field in fields: - documents_table[field].append(document.get(field, "")) + documents_table[field].append(document.get(field, None)) documents_path = os.path.join(".", "duckdb_tmp", "documents", f"{index}.parquet") documents_table = pa.Table.from_pydict(documents_table) @@ -78,7 +81,7 @@ def insert_documents( schema: str, df: list[dict] | str, key: str, - fields: list[str] | str, + columns: list[str] | str, dtypes: dict[str, str] | None = None, batch_size: int = 30_000, n_jobs: int = -1, @@ -97,7 +100,7 @@ def insert_documents( The list of document dictionaries or a string (URL) for a Hugging Face dataset to insert. 
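As a quick illustration of what the reworked `write_parquet` does with heterogeneous documents, here is a hedged, self-contained sketch (the example documents and output file name are made up): the union of fields is collected first, and fields missing from a document end up as nulls in the Parquet file rather than empty strings.

```python
import collections

import pyarrow as pa
import pyarrow.parquet as pq

documents = [
    {"id": "1", "title": "Deep river blues", "style": "country blues"},
    {"id": "2", "title": "Folsom Prison Blues"},  # no "style" field
]

# Union of all fields across documents, excluding the key.
fields = {field for document in documents for field in document if field != "id"}

table = collections.defaultdict(list)
for document in documents:
    table["id"].append(document["id"])
    for field in fields:
        table[field].append(document.get(field, None))  # missing values become null

pq.write_table(pa.Table.from_pydict(table), "0.parquet")
print(pq.read_table("0.parquet").to_pydict())
```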
key The field that uniquely identifies each document (e.g., 'id'). - fields + columns The list of document fields to insert. Can be a string if inserting a single field. dtypes Optional dictionary specifying the DuckDB type for each field. Defaults to 'VARCHAR' for all unspecified fields. @@ -122,33 +125,17 @@ def insert_documents( ... database="test.duckdb", ... schema="bm25_tables", ... key="id", - ... fields=["title", "text"], + ... columns=["title", "text"], ... df=df ... ) - """ - if isinstance(fields, str): - fields = [fields] - - fields = [field for field in fields if field != "id"] - - if isinstance(df, str): - return hf_insert_documents( - database=database, - schema=schema, - key=key, - fields=fields, - url=df, - config=config, - limit=limit, - ) - if isinstance(df, pd.DataFrame): - df = df.to_dict(orient="records") + """ + columns = [column for column in columns if column != "id"] create_documents( database=database, schema=schema, - fields=fields, + columns=columns, config=config, dtypes=dtypes, ) @@ -165,7 +152,7 @@ def insert_documents( delayed(function=write_parquet)( batch, index, - fields, + columns, key, ) for index, batch in enumerate( @@ -179,9 +166,9 @@ def insert_documents( parquet_files=os.path.join(documents_path, "*.parquet"), config=config, key_field=f"df.{key}", - fields=", ".join(fields), - df_fields=", ".join([f"df.{field}" for field in fields]), - src_fields=", ".join([f"src.{field}" for field in fields]), + fields=", ".join(columns), + df_fields=", ".join([f"df.{field}" for field in columns]), + src_fields=", ".join([f"src.{field}" for field in columns]), ) if os.path.exists(path=documents_path): diff --git a/ducksearch/tables/select.py b/ducksearch/tables/select.py index 33d8d34..f1a7534 100644 --- a/ducksearch/tables/select.py +++ b/ducksearch/tables/select.py @@ -1,3 +1,5 @@ +import pandas as pd + from ..decorators import execute_with_duckdb @@ -6,9 +8,44 @@ read_only=True, fetch_df=True, ) -def select_documents() -> list[dict]: +def _select_documents() -> list[dict]: + """Select all documents from the documents table. + + Returns + ------- + list[dict] + A list of dictionaries representing the documents. + + Examples + -------- + >>> from ducksearch import tables + + >>> documents = tables.select_documents( + ... database="test.duckdb", + ... schema="bm25_tables", + ... ) + + >>> assert len(documents) == 3 + """ + + +def select_documents( + database: str, + schema: str, + limit: int | None = None, + config: dict | None = None, +) -> list[dict]: """Select all documents from the documents table. + Parameters + ---------- + database + The name of the DuckDB database. + schema + The schema where the documents table is located. + config + Optional configuration options for the DuckDB connection. 
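The refactored `select_documents` wrapper (continued below) accepts an optional `limit` and returns a `pandas.DataFrame`. A usage sketch, reusing the database name from the docstring example:

```python
from ducksearch import tables

# Returns a DataFrame of the documents table, capped at 10 rows.
df = tables.select_documents(
    database="test.duckdb",
    schema="bm25_tables",
    limit=10,
)
print(df.head())
```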
+ Returns ------- list[dict] @@ -25,6 +62,14 @@ def select_documents() -> list[dict]: >>> assert len(documents) == 3 """ + return pd.DataFrame( + _select_documents( + database=database, + schema=schema, + limit="" if limit is None else f"LIMIT {limit}", + config=config, + ) + ) @execute_with_duckdb( diff --git a/ducksearch/tables/select/documents.sql b/ducksearch/tables/select/documents.sql index ff32ad9..b8031be 100644 --- a/ducksearch/tables/select/documents.sql +++ b/ducksearch/tables/select/documents.sql @@ -1,3 +1,4 @@ SELECT * FROM {schema}.documents -ORDER BY id ASC; +ORDER BY id ASC +{limit}; diff --git a/ducksearch/tables/update.py b/ducksearch/tables/update.py new file mode 100644 index 0000000..b864f22 --- /dev/null +++ b/ducksearch/tables/update.py @@ -0,0 +1,55 @@ +from ..decorators import execute_with_duckdb + + +@execute_with_duckdb( + relative_path="tables/update/documents.sql", +) +def _add_columns_documents() -> None: + """Add columns to the documents table in the DuckDB database. + + Parameters + ---------- + database: str + The name of the DuckDB database. + config: dict, optional + The configuration options for the DuckDB connection. + """ + + +def add_columns_documents( + database: str, + schema: str, + columns: list[str] | str, + dtypes: dict = None, + config: dict = None, +) -> None: + """Add columns to the documents table in the DuckDB database. + + Parameters + ---------- + database: + The name of the DuckDB database. + schema: + The schema in which the documents table is located. + columns: + The columns to add to the documents table. + dtypes: + The data types for the columns to add. + config: + The configuration options for the DuckDB connection. + + """ + if isinstance(columns, str): + columns = [columns] + + if dtypes is None: + dtypes = {} + + _add_columns_documents( + database=database, + schema=schema, + fields=", ".join( + [f"ADD COLUMN {field} {dtypes.get(field, 'VARCHAR')}" for field in columns] + ), + config=config, + ) diff --git a/ducksearch/tables/update/documents.sql b/ducksearch/tables/update/documents.sql new file mode 100644 index 0000000..07e2b51 --- /dev/null +++ b/ducksearch/tables/update/documents.sql @@ -0,0 +1,3 @@ +ALTER TABLE {schema}.documents + {fields} +; \ No newline at end of file diff --git a/ducksearch/upload/upload.py b/ducksearch/upload/upload.py index fce9881..741f336 100644 --- a/ducksearch/upload/upload.py +++ b/ducksearch/upload/upload.py @@ -1,5 +1,9 @@ +import pandas as pd + +from ..hf import insert_documents as hf_insert_documents from ..search import update_index_documents, update_index_queries from ..tables import ( + add_columns_documents, create_documents, create_documents_queries, create_queries, @@ -7,8 +11,9 @@ insert_documents, insert_documents_queries, insert_queries, + select_documents_columns, ) -from ..utils import plot +from ..utils import get_list_columns_df, plot def documents( @@ -28,6 +33,7 @@ def documents( dtypes: dict[str, str] | None = None, config: dict | None = None, limit: int | None = None, + tqdm_bar: bool = True, ) -> str: """Upload documents to DuckDB, create necessary schema, and index using BM25. @@ -62,6 +68,8 @@ def documents( Number of parallel jobs to use for uploading documents. Default use all available processors. config Optional configuration dictionary for the DuckDB connection and other settings. 
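To make the new `add_columns_documents` helper concrete, here is an illustrative sketch of the clause it assembles for the `tables/update/documents.sql` template; the schema and column names are made up, and columns with no entry in `dtypes` fall back to `VARCHAR`:

```python
columns_to_add = ["popularity", "tags"]
dtypes = {"popularity": "INT"}

# Same string-building as _add_columns_documents: one ADD COLUMN clause per column.
fields = ", ".join(
    f"ADD COLUMN {column} {dtypes.get(column, 'VARCHAR')}" for column in columns_to_add
)
print(f"ALTER TABLE bm25_tables.documents {fields};")
# ALTER TABLE bm25_tables.documents ADD COLUMN popularity INT, ADD COLUMN tags VARCHAR;
```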
+ tqdm_bar + Whether to display a progress bar when uploading documents Returns ------- @@ -71,8 +79,6 @@ def documents( """ schema = "bm25_tables" - fields = [field for field in fields if field != "id"] - create_schema( database=database, schema=schema, @@ -85,35 +91,72 @@ def documents( config=config, ) - create_documents( - database=database, - schema=schema, - dtypes=dtypes, - fields=fields, - config=config, + columns = get_list_columns_df( + documents=documents, ) - create_documents_queries( - database=database, - schema=schema, - config=config, - ) + if isinstance(documents, str): + hf_insert_documents( + database=database, + schema=schema, + key=key, + url=documents, + config=config, + limit=limit, + dtypes=dtypes, + ) + + else: + if isinstance(documents, pd.DataFrame): + documents = documents.to_dict(orient="records") + + create_documents( + database=database, + schema=schema, + dtypes=dtypes, + columns=columns, + config=config, + ) + + existing_columns = select_documents_columns( + database=database, + schema=schema, + config=config, + ) + + existing_columns = set(existing_columns) + columns_to_add = set(columns) - existing_columns + if columns_to_add: + add_columns_documents( + database=database, + schema=schema, + columns=list(columns_to_add), + dtypes=dtypes, + config=config, + ) - insert_documents( + insert_documents( + database=database, + schema=schema, + df=documents, + key=key, + columns=columns, + batch_size=batch_size, + dtypes=dtypes, + n_jobs=n_jobs, + config=config, + limit=limit, + ) + + create_documents_queries( database=database, schema=schema, - df=documents, - key=key, - fields=fields, - batch_size=batch_size, - dtypes=dtypes, - n_jobs=n_jobs, config=config, - limit=limit, ) update_index_documents( database=database, + fields=fields, b=b, k1=k1, stemmer=stemmer, diff --git a/ducksearch/utils/__init__.py b/ducksearch/utils/__init__.py index 6bccb10..a443c2d 100644 --- a/ducksearch/utils/__init__.py +++ b/ducksearch/utils/__init__.py @@ -1,4 +1,13 @@ from .batch import batchify +from .columns import get_list_columns_df +from .hash import generate_random_hash +from .parralel_tqdm import ParallelTqdm from .plot import plot -__all__ = ["batchify", "plot"] +__all__ = [ + "batchify", + "get_list_columns_df", + "generate_random_hash", + "plot", + "ParallelTqdm", +] diff --git a/ducksearch/utils/columns.py b/ducksearch/utils/columns.py new file mode 100644 index 0000000..fdae728 --- /dev/null +++ b/ducksearch/utils/columns.py @@ -0,0 +1,20 @@ +import pandas as pd + + +def get_list_columns_df( + documents: list[dict] | pd.DataFrame, +) -> list[str]: + """Get a list of columns from a list of dictionaries or a DataFrame.""" + columns = None + if isinstance(documents, pd.DataFrame): + return list(documents.columns) + + if isinstance(documents, list): + columns = set() + for document in documents: + for column in document.keys(): + if column != "id": + columns.add(column) + return list(columns) + + return None diff --git a/ducksearch/utils/hash.py b/ducksearch/utils/hash.py new file mode 100644 index 0000000..f1918be --- /dev/null +++ b/ducksearch/utils/hash.py @@ -0,0 +1,11 @@ +import hashlib +import secrets + + +def generate_random_hash() -> str: + """Generate a random SHA-256 hash.""" + random_data = secrets.token_bytes(32) + hash_obj = hashlib.sha256() + hash_obj.update(random_data) + random_hash = hash_obj.hexdigest() + return random_hash diff --git a/ducksearch/utils/parralel_tqdm.py b/ducksearch/utils/parralel_tqdm.py new file mode 100644 index 0000000..480a1a6 --- 
/dev/null
+++ b/ducksearch/utils/parralel_tqdm.py
@@ -0,0 +1,70 @@
+import tqdm
+from joblib import Parallel
+
+
+class ParallelTqdm(Parallel):
+    """joblib.Parallel, but with a tqdm progress bar.
+
+    Parameters
+    ----------
+    total : int
+        The total number of tasks to complete.
+    desc : str
+        A description of the task.
+    tqdm_bar : bool, optional
+        Whether to display a tqdm progress bar. Default is True.
+    show_joblib_header : bool, optional
+        Whether to display the joblib header. Default is False.
+
+    References
+    ----------
+    https://github.com/joblib/joblib/issues/972
+    """
+
+    def __init__(
+        self,
+        *,
+        total: int,
+        desc: str,
+        tqdm_bar: bool = True,
+        show_joblib_header: bool = False,
+        **kwargs,
+    ) -> None:
+        super().__init__(verbose=(1 if show_joblib_header else 0), **kwargs)
+        self.total = total
+        self.desc = desc
+        self.tqdm_bar = tqdm_bar
+        self.progress_bar: tqdm.tqdm | None = None
+
+    def __call__(self, iterable):
+        try:
+            return super().__call__(iterable)
+        finally:
+            if self.progress_bar is not None:
+                self.progress_bar.close()
+
+    __call__.__doc__ = Parallel.__call__.__doc__
+
+    def dispatch_one_batch(self, iterator):
+        """Dispatch a batch of tasks and update the progress bar."""
+        if self.progress_bar is None and self.tqdm_bar:
+            self.progress_bar = tqdm.tqdm(
+                desc=self.desc,
+                total=self.total,
+                position=0,
+                disable=not self.tqdm_bar,
+                unit="tasks",
+            )
+        return super().dispatch_one_batch(iterator=iterator)
+
+    dispatch_one_batch.__doc__ = Parallel.dispatch_one_batch.__doc__
+
+    def print_progress(self):
+        """Display the progress of the parallel execution using tqdm."""
+        if self.total is None and self._original_iterator is None:
+            self.total = self.n_dispatched_tasks
+            self.progress_bar.total = self.total
+            self.progress_bar.refresh()
+
+        if self.tqdm_bar:
+            self.progress_bar.update(self.n_completed_tasks - self.progress_bar.n)
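A small usage sketch for `ParallelTqdm`, assuming it behaves like `joblib.Parallel` with a progress bar; the `square` task and the `n_jobs` value are illustrative:

```python
from joblib import delayed

from ducksearch.utils import ParallelTqdm


def square(x: int) -> int:
    return x * x


tasks = [delayed(square)(index) for index in range(100)]

# total and desc are keyword-only; remaining kwargs are forwarded to joblib.Parallel.
results = ParallelTqdm(total=len(tasks), desc="Squaring", n_jobs=2)(tasks)
print(sum(results))
```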