From 63ae52baad0e63bef1f8e1aadbd4916db2707921 Mon Sep 17 00:00:00 2001 From: Spike Lu Date: Wed, 11 Sep 2024 15:57:31 -0700 Subject: [PATCH] paying for the sins of our fathers --- chromadb/api/models/CollectionCommon.py | 253 ++++++---------------- chromadb/api/types.py | 81 ++++--- chromadb/test/api/test_api_update.py | 2 +- chromadb/test/api/test_validations.py | 65 +++++- chromadb/test/property/test_embeddings.py | 16 +- 5 files changed, 191 insertions(+), 226 deletions(-) diff --git a/chromadb/api/models/CollectionCommon.py b/chromadb/api/models/CollectionCommon.py index 0175156e3ba..9aaf2a16c4e 100644 --- a/chromadb/api/models/CollectionCommon.py +++ b/chromadb/api/models/CollectionCommon.py @@ -25,7 +25,6 @@ Include, Loadable, Metadata, - Metadatas, Document, Documents, Image, @@ -47,7 +46,7 @@ validate_n_results, validate_where, validate_where_document, - does_record_set_contain_any_data, + record_set_contains_one_of, ) # TODO: We should rename the types in chromadb.types to be Models where @@ -146,105 +145,54 @@ def __repr__(self) -> str: def get_model(self) -> CollectionModel: return self._model + @staticmethod def _unpack_record_set( - self, ids: OneOrMany[ID], - embeddings: Optional[ # type: ignore[type-arg] - Union[ - OneOrMany[Embedding], - OneOrMany[np.ndarray], - ] - ], - metadatas: Optional[OneOrMany[Metadata]], - documents: Optional[OneOrMany[Document]], + embeddings: Optional[Union[OneOrMany[Embedding], OneOrMany[np.ndarray]]] = None, # type: ignore[type-arg] + metadatas: Optional[OneOrMany[Metadata]] = None, + documents: Optional[OneOrMany[Document]] = None, images: Optional[OneOrMany[Image]] = None, uris: Optional[OneOrMany[URI]] = None, ) -> RecordSet: - unpacked_ids = maybe_cast_one_to_many(ids) - unpacked_embeddings = maybe_cast_one_to_many_embedding(embeddings) - unpacked_metadatas = maybe_cast_one_to_many(metadatas) - unpacked_documents = maybe_cast_one_to_many(documents) - unpacked_images = maybe_cast_one_to_many(images) - unpacked_uris = maybe_cast_one_to_many(uris) return { - "ids": cast(IDs, unpacked_ids), - "embeddings": unpacked_embeddings, - "metadatas": unpacked_metadatas, - "documents": unpacked_documents, - "images": unpacked_images, - "uris": unpacked_uris, + "ids": cast(IDs, maybe_cast_one_to_many(ids)), + "embeddings": maybe_cast_one_to_many_embedding(embeddings), + "metadatas": maybe_cast_one_to_many(metadatas), + "documents": maybe_cast_one_to_many(documents), + "images": maybe_cast_one_to_many(images), + "uris": maybe_cast_one_to_many(uris), } + @staticmethod def _validate_record_set( - self, - ids: IDs, - embeddings: Optional[Embeddings], - metadatas: Optional[Metadatas], - documents: Optional[Documents], - images: Optional[Images], - uris: Optional[URIs], - require_embeddings_or_data: bool = True, + record_set: RecordSet, + require_data: bool, ) -> None: - valid_ids = validate_ids(ids) - valid_embeddings = ( - validate_embeddings(embeddings) if embeddings is not None else None - ) - valid_metadatas = ( - validate_metadatas(metadatas) if metadatas is not None else None - ) - - # No additional validation needed for documents, images, or uris - valid_documents = documents - valid_images = images - valid_uris = uris - - # Check that one of embeddings or ducuments or images is provided - if require_embeddings_or_data: - if ( - valid_embeddings is None - and valid_documents is None - and valid_images is None - and valid_uris is None - ): - raise ValueError( - "You must provide embeddings, documents, images, or uris." - ) - else: - # will replace this with does_record_set_contain_any_data in the following PR - if ( - valid_embeddings is None - and valid_documents is None - and valid_images is None - and valid_uris is None - and valid_metadatas is None - ): - raise ValueError("You must provide either data or metadatas.") + validate_ids(record_set["ids"]) + validate_embeddings(record_set["embeddings"]) if record_set[ + "embeddings" + ] is not None else None + validate_metadatas(record_set["metadatas"]) if record_set[ + "metadatas" + ] is not None else None # Only one of documents or images can be provided - if valid_documents is not None and valid_images is not None: + if record_set["documents"] is not None and record_set["images"] is not None: raise ValueError("You can only provide documents or images, not both.") - # Check that, if they're provided, the lengths of the arrays match the length of ids - if valid_embeddings is not None and len(valid_embeddings) != len(valid_ids): - raise ValueError( - f"Number of embeddings {len(valid_embeddings)} must match number of ids {len(valid_ids)}" - ) - if valid_metadatas is not None and len(valid_metadatas) != len(valid_ids): - raise ValueError( - f"Number of metadatas {len(valid_metadatas)} must match number of ids {len(valid_ids)}" - ) - if valid_documents is not None and len(valid_documents) != len(valid_ids): - raise ValueError( - f"Number of documents {len(valid_documents)} must match number of ids {len(valid_ids)}" - ) - if valid_images is not None and len(valid_images) != len(valid_ids): - raise ValueError( - f"Number of images {len(valid_images)} must match number of ids {len(valid_ids)}" - ) - if valid_uris is not None and len(valid_uris) != len(valid_ids): - raise ValueError( - f"Number of uris {len(valid_uris)} must match number of ids {len(valid_ids)}" - ) + required_fields: Include = ["embeddings", "documents", "images", "uris"] # type: ignore[list-item] + if not require_data: + required_fields += ["metadatas"] # type: ignore[list-item] + + if not record_set_contains_one_of(record_set, include=required_fields): + raise ValueError(f"You must provide one of {required_fields}") + + valid_ids = record_set["ids"] + for key in ["embeddings", "metadatas", "documents", "images", "uris"]: + if record_set[key] is not None and len(record_set[key]) != len(valid_ids): # type: ignore[literal-required] + raise ValueError( + f"Number of {key} {len(record_set[key])} must match number of ids {len(valid_ids)}" # type: ignore[literal-required] + ) def _compute_embeddings( self, @@ -362,14 +310,7 @@ def _validate_and_prepare_query_request( valid_include = validate_include(include, allow_distances=True) valid_n_results = validate_n_results(n_results) - embeddings_to_normalize = maybe_cast_one_to_many_embedding(query_embeddings) - normalized_embeddings = ( - self._normalize_embeddings(embeddings_to_normalize) - if embeddings_to_normalize is not None - else None - ) - - valid_query_embeddings = None + normalized_embeddings = maybe_cast_one_to_many_embedding(query_embeddings) if normalized_embeddings is not None: valid_query_embeddings = validate_embeddings(normalized_embeddings) else: @@ -436,7 +377,7 @@ def _process_add_request( images: Optional[OneOrMany[Image]] = None, uris: Optional[OneOrMany[URI]] = None, ) -> RecordSet: - unpacked_record_set = self._unpack_record_set( + record_set = self._unpack_record_set( ids=ids, embeddings=embeddings, metadatas=metadatas, @@ -445,39 +386,19 @@ def _process_add_request( uris=uris, ) - normalized_embeddings = ( - self._normalize_embeddings(unpacked_record_set["embeddings"]) - if unpacked_record_set["embeddings"] is not None - else None - ) - self._validate_record_set( - ids=unpacked_record_set["ids"], - embeddings=normalized_embeddings, - metadatas=unpacked_record_set["metadatas"], - documents=unpacked_record_set["documents"], - images=unpacked_record_set["images"], - uris=unpacked_record_set["uris"], + record_set, + require_data=True, ) - prepared_embeddings = ( - self._compute_embeddings( - documents=unpacked_record_set["documents"], - images=unpacked_record_set["images"], - uris=unpacked_record_set["uris"], + if record_set["embeddings"] is None: + record_set["embeddings"] = self._compute_embeddings( + documents=record_set["documents"], + images=record_set["images"], + uris=record_set["uris"], ) - if normalized_embeddings is None - else normalized_embeddings - ) - return { - "ids": unpacked_record_set["ids"], - "embeddings": prepared_embeddings, - "metadatas": unpacked_record_set["metadatas"], - "documents": unpacked_record_set["documents"], - "images": unpacked_record_set["images"], - "uris": unpacked_record_set["uris"], - } + return record_set def _process_upsert_request( self, @@ -493,7 +414,7 @@ def _process_upsert_request( images: Optional[OneOrMany[Image]] = None, uris: Optional[OneOrMany[URI]] = None, ) -> RecordSet: - unpacked_record_set = self._unpack_record_set( + record_set = self._unpack_record_set( ids=ids, embeddings=embeddings, metadatas=metadatas, @@ -502,37 +423,20 @@ def _process_upsert_request( uris=uris, ) - normalized_embeddings = ( - self._normalize_embeddings(unpacked_record_set["embeddings"]) - if unpacked_record_set["embeddings"] is not None - else None - ) - self._validate_record_set( - ids=unpacked_record_set["ids"], - embeddings=normalized_embeddings, - metadatas=unpacked_record_set["metadatas"], - documents=unpacked_record_set["documents"], - images=unpacked_record_set["images"], - uris=unpacked_record_set["uris"], + record_set, + require_data=True, ) - prepared_embeddings = normalized_embeddings - if prepared_embeddings is None: - prepared_embeddings = self._compute_embeddings( - documents=unpacked_record_set["documents"], - images=unpacked_record_set["images"], + # TODO: Correctly handle Upsert for URIs + if record_set["embeddings"] is None: + record_set["embeddings"] = self._compute_embeddings( + documents=record_set["documents"], + images=record_set["images"], uris=None, ) - return { - "ids": unpacked_record_set["ids"], - "embeddings": prepared_embeddings, - "metadatas": unpacked_record_set["metadatas"], - "documents": unpacked_record_set["documents"], - "images": unpacked_record_set["images"], - "uris": unpacked_record_set["uris"], - } + return record_set def _process_update_request( self, @@ -548,7 +452,7 @@ def _process_update_request( images: Optional[OneOrMany[Image]] = None, uris: Optional[OneOrMany[URI]] = None, ) -> RecordSet: - unpacked_record_set = self._unpack_record_set( + record_set = self._unpack_record_set( ids=ids, embeddings=embeddings, metadatas=metadatas, @@ -557,40 +461,22 @@ def _process_update_request( uris=uris, ) - normalized_embeddings = ( - self._normalize_embeddings(unpacked_record_set["embeddings"]) - if unpacked_record_set["embeddings"] is not None - else None - ) - self._validate_record_set( - ids=unpacked_record_set["ids"], - embeddings=normalized_embeddings, - metadatas=unpacked_record_set["metadatas"], - documents=unpacked_record_set["documents"], - images=unpacked_record_set["images"], - uris=unpacked_record_set["uris"], - require_embeddings_or_data=False, + record_set, + require_data=False, ) - prepared_embeddings = normalized_embeddings - if prepared_embeddings is None and does_record_set_contain_any_data( - unpacked_record_set, include=["documents", "images"] + # TODO: Correctly handle Update for URIs + if record_set["embeddings"] is None and record_set_contains_one_of( + record_set, include=["documents", "images"] # type: ignore[list-item] ): - prepared_embeddings = self._compute_embeddings( - documents=unpacked_record_set["documents"], - images=unpacked_record_set["images"], + record_set["embeddings"] = self._compute_embeddings( + documents=record_set["documents"], + images=record_set["images"], uris=None, ) - return { - "ids": unpacked_record_set["ids"], - "embeddings": prepared_embeddings, - "metadatas": unpacked_record_set["metadatas"], - "documents": unpacked_record_set["documents"], - "images": unpacked_record_set["images"], - "uris": unpacked_record_set["uris"], - } + return record_set def _validate_and_prepare_delete_request( self, @@ -606,17 +492,6 @@ def _validate_and_prepare_delete_request( return (ids, where, where_document) - @staticmethod - def _normalize_embeddings( - embeddings: Union[ # type: ignore[type-arg] - OneOrMany[Embedding], - OneOrMany[np.ndarray], - ] - ) -> Embeddings: - if isinstance(embeddings, np.ndarray): - return embeddings.tolist() # type: ignore - return embeddings # type: ignore - def _embed(self, input: Any) -> Embeddings: if self._embedding_function is None: raise ValueError( diff --git a/chromadb/api/types.py b/chromadb/api/types.py index e0e2704e2f8..0d681812d25 100644 --- a/chromadb/api/types.py +++ b/chromadb/api/types.py @@ -54,17 +54,45 @@ def maybe_cast_one_to_many(target: Optional[(OneOrMany[T])]) -> Optional[List[T] def maybe_cast_one_to_many_embedding( - target: Optional[Union[OneOrMany[Embedding], OneOrMany[OneOrMany[np.ndarray]]]], # type: ignore[type-arg] + target: Optional[Union[OneOrMany[Embedding], OneOrMany[np.ndarray]]], # type: ignore[type-arg] ) -> Optional[Embeddings]: # No target if target is None: return None - if isinstance(target, List): - # One Embedding + if isinstance(target, np.ndarray): + dim = target.ndim + if dim == 1: + # TODO: Remove this conversion when unpacking + return cast(Embeddings, [target.tolist()]) + if dim == 2: + return cast(Embeddings, target.tolist()) + raise ValueError( + f"Expected embeddings to be a 1D or 2D numpy array, got {dim}D" + ) + + if isinstance(target, list): + if len(target) == 0: + raise ValueError( + "Expected embeddings to be a list or a numpy array with at least one item" + ) + + # target represents a single embedding as a list if isinstance(target[0], (int, float)): return cast(Embeddings, [target]) - # Already a sequence + + # Check if the first item is a numpy array - target is a list of numpy arrays + if isinstance(target[0], np.ndarray): + # Check all the embeddings are 1D + for embedding in target: + dim = (cast(np.ndarray, embedding)).ndim # type: ignore[type-arg] + if dim != 1: + raise ValueError( + f"Expected embeddings to be a list of 1D numpy arrays, got a {dim}D numpy array" + ) + return [cast(np.ndarray, embedding).tolist() for embedding in target] # type: ignore[type-arg] + + # target is a list of lists representing embeddings return cast(Embeddings, target) @@ -112,7 +140,21 @@ class IncludeEnum(str, Enum): data = "data" -# Record set +# This should ust be List[Literal["documents", "embeddings", "metadatas", "distances"]] +# However, this provokes an incompatibility with the Overrides library and Python 3.7 +Include = List[IncludeEnum] +IncludeMetadataDocuments = Field(default=["metadatas", "documents"]) +IncludeMetadataDocumentsEmbeddings = Field( + default=["metadatas", "documents", "embeddings"] +) +IncludeMetadataDocumentsEmbeddingsDistances = Field( + default=["metadatas", "documents", "embeddings", "distances"] +) +IncludeMetadataDocumentsDistances = Field( + default=["metadatas", "documents", "distances"] +) + + class RecordSet(TypedDict): ids: IDs embeddings: Optional[Embeddings] @@ -122,47 +164,34 @@ class RecordSet(TypedDict): uris: Optional[URIs] -def does_record_set_contain_any_data(record_set: RecordSet, include: List[str]) -> bool: +def record_set_contains_one_of(record_set: RecordSet, include: Include) -> bool: + """Check if the record set contains data for any of the given include keys""" if len(include) == 0: raise ValueError("Expected include to be a non-empty list") error_messages = [] - for key in include: - if key not in record_set: + for include_key in include: + if include_key not in record_set: error_messages.append( - f"Expected include key to be a a known field of RecordSet, got {key}" + f"Expected include key to be a a known field of RecordSet, got {include_key}" ) if len(error_messages) > 0: raise ValueError(", ".join(error_messages)) - for key, value in record_set.items(): - if key not in include: + for record_key, value in record_set.items(): + if record_key not in include: continue if isinstance(value, list): if len(value) == 0: - raise ValueError(f"Expected {key} to be a non-empty list") + raise ValueError(f"Expected {record_key} to be a non-empty list") return True return False -# This should ust be List[Literal["documents", "embeddings", "metadatas", "distances"]] -# However, this provokes an incompatibility with the Overrides library and Python 3.7 -Include = List[IncludeEnum] -IncludeMetadataDocuments = Field(default=["metadatas", "documents"]) -IncludeMetadataDocumentsEmbeddings = Field( - default=["metadatas", "documents", "embeddings"] -) -IncludeMetadataDocumentsEmbeddingsDistances = Field( - default=["metadatas", "documents", "embeddings", "distances"] -) -IncludeMetadataDocumentsDistances = Field( - default=["metadatas", "documents", "distances"] -) - # Re-export types from chromadb.types LiteralValue = LiteralValue LogicalOperator = LogicalOperator diff --git a/chromadb/test/api/test_api_update.py b/chromadb/test/api/test_api_update.py index a83660fb069..a3d510e701f 100644 --- a/chromadb/test/api/test_api_update.py +++ b/chromadb/test/api/test_api_update.py @@ -2,7 +2,7 @@ from chromadb.api import ClientAPI -def test_update_query(client: ClientAPI) -> None: +def test_update_query_with_none_data(client: ClientAPI) -> None: client.reset() collection = client.create_collection("test_update_query") diff --git a/chromadb/test/api/test_validations.py b/chromadb/test/api/test_validations.py index 528a361ab60..3264470c68f 100644 --- a/chromadb/test/api/test_validations.py +++ b/chromadb/test/api/test_validations.py @@ -1,5 +1,10 @@ import pytest -from chromadb.api.types import RecordSet, does_record_set_contain_any_data +import numpy as np +from chromadb.api.types import ( + RecordSet, + record_set_contains_one_of, + maybe_cast_one_to_many_embedding, +) def test_does_record_set_contain_any_data() -> None: @@ -23,23 +28,71 @@ def test_does_record_set_contain_any_data() -> None: } with pytest.raises(ValueError) as e: - does_record_set_contain_any_data(record_set_non_list, include=["embeddings"]) + record_set_contains_one_of(record_set_non_list, include=["embeddings"]) # type: ignore[list-item] assert "Expected embeddings to be a non-empty list" in str(e) # Test case 2: Non-list field with pytest.raises(ValueError) as e: - does_record_set_contain_any_data(valid_record_set, include=[]) + record_set_contains_one_of(valid_record_set, include=[]) assert "Expected include to be a non-empty list" in str(e) # Test case 3: Non-existent field with pytest.raises(ValueError) as e: - does_record_set_contain_any_data( - valid_record_set, include=["non_existent_field"] - ) + record_set_contains_one_of(valid_record_set, include=["non_existent_field"]) # type: ignore[list-item] assert ( "Expected include key to be a a known field of RecordSet, got non_existent_field" in str(e) ) + + +def test_maybe_cast_one_to_many_embedding() -> None: + # Test with None input + assert maybe_cast_one_to_many_embedding(None) is None + + # Test with a single embedding as a list + single_embedding = [1.0, 2.0, 3.0] + result = maybe_cast_one_to_many_embedding(single_embedding) + assert result == [single_embedding] + + # Test with multiple embeddings as a list of lists + multiple_embeddings = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]] + result = maybe_cast_one_to_many_embedding(multiple_embeddings) # type: ignore[arg-type] + assert result == multiple_embeddings + + # Test with a numpy array (single embedding) + np_single = np.array([1.0, 2.0, 3.0]) + result = maybe_cast_one_to_many_embedding(np_single) + assert isinstance(result, list) + assert len(result) == 1 + assert np.array_equal(result[0], np_single) + + # Test with a numpy array (multiple embeddings) + np_multiple = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) + result = maybe_cast_one_to_many_embedding(np_multiple) + assert isinstance(result, list) + assert len(result) == 2 + assert np.array_equal(result, np_multiple) + + # Test with an empty list (should raise ValueError) + with pytest.raises( + ValueError, + match="Expected embeddings to be a list or a numpy array with at least one item", + ): + maybe_cast_one_to_many_embedding([]) + + # Test with an empty list (should raise ValueError) + with pytest.raises( + ValueError, + match="Expected embeddings to be a list or a numpy array with at least one item", + ): + maybe_cast_one_to_many_embedding(np.array([])) + + # Test with an empty str (should raise ValueError) + with pytest.raises( + ValueError, + match="Expected embeddings to be a list or a numpy array, got str", + ): + maybe_cast_one_to_many_embedding("") # type: ignore[arg-type] diff --git a/chromadb/test/property/test_embeddings.py b/chromadb/test/property/test_embeddings.py index dc53bbc52d7..c3b2ae0558d 100644 --- a/chromadb/test/property/test_embeddings.py +++ b/chromadb/test/property/test_embeddings.py @@ -8,7 +8,14 @@ from hypothesis import given, settings, HealthCheck from typing import Dict, Set, cast, Union, DefaultDict, Any, List from dataclasses import dataclass -from chromadb.api.types import ID, Embeddings, Include, IDs, validate_embeddings +from chromadb.api.types import ( + ID, + Embeddings, + Include, + IDs, + validate_embeddings, + maybe_cast_one_to_many_embedding, +) from chromadb.config import System import chromadb.errors as errors from chromadb.api import ClientAPI @@ -796,7 +803,8 @@ def test_autocasting_validate_embeddings_for_compatible_types( supported_types: List[Any], ) -> None: embds = strategies.create_embeddings(10, 10, supported_types) - validated_embeddings = validate_embeddings(Collection._normalize_embeddings(embds)) + + validated_embeddings = validate_embeddings(maybe_cast_one_to_many_embedding(embds)) # type: ignore[arg-type] assert all( [ isinstance(value, list) @@ -816,7 +824,7 @@ def test_autocasting_validate_embeddings_with_ndarray( supported_types: List[Any], ) -> None: embds = strategies.create_embeddings_ndarray(10, 10, supported_types) - validated_embeddings = validate_embeddings(Collection._normalize_embeddings(embds)) + validated_embeddings = validate_embeddings(maybe_cast_one_to_many_embedding(embds)) # type: ignore[arg-type] assert all( [ isinstance(value, list) @@ -837,7 +845,7 @@ def test_autocasting_validate_embeddings_incompatible_types( ) -> None: embds = strategies.create_embeddings(10, 10, unsupported_types) with pytest.raises(ValueError) as e: - validate_embeddings(Collection._normalize_embeddings(embds)) + validate_embeddings(maybe_cast_one_to_many_embedding(embds)) # type: ignore[arg-type] assert "Expected each value in the embedding to be a int or float" in str(e)