diff --git a/chromadb/api/models/AsyncCollection.py b/chromadb/api/models/AsyncCollection.py index 61a00bdc6b6..350fa27af51 100644 --- a/chromadb/api/models/AsyncCollection.py +++ b/chromadb/api/models/AsyncCollection.py @@ -4,13 +4,14 @@ Union, cast, ) -import numpy as np +from numpy.typing import NDArray from chromadb.api.types import ( URI, CollectionMetadata, Embedding, Embeddings, + EmbeddingDType, Include, Metadata, Document, @@ -37,7 +38,7 @@ async def add( embeddings: Optional[ Union[ OneOrMany[Embedding], - OneOrMany[np.ndarray], + OneOrMany[NDArray[EmbeddingDType]], ] ] = None, metadatas: Optional[OneOrMany[Metadata]] = None, @@ -99,7 +100,7 @@ async def get( limit: Optional[int] = None, offset: Optional[int] = None, where_document: Optional[WhereDocument] = None, - include: Include = ["metadatas", "documents"], + include: Include = ["metadatas", "documents"], # type: ignore[list-item] ) -> GetResult: """Get embeddings and their associate data from the data store. If no ids or where filter is provided returns all embeddings up to limit starting at offset. @@ -152,7 +153,7 @@ async def query( query_embeddings: Optional[ Union[ OneOrMany[Embedding], - OneOrMany[np.ndarray], + OneOrMany[NDArray[EmbeddingDType]], ] ] = None, query_texts: Optional[OneOrMany[Document]] = None, @@ -161,7 +162,7 @@ async def query( n_results: int = 10, where: Optional[Where] = None, where_document: Optional[WhereDocument] = None, - include: Include = ["metadatas", "documents", "distances"], + include: Include = ["metadatas", "documents", "distances"], # type: ignore[list-item] ) -> QueryResult: """Get the n_results nearest neighbor embeddings for provided query_embeddings or query_texts. @@ -240,7 +241,7 @@ async def update( embeddings: Optional[ Union[ OneOrMany[Embedding], - OneOrMany[np.ndarray], + OneOrMany[NDArray[EmbeddingDType]], ] ] = None, metadatas: Optional[OneOrMany[Metadata]] = None, @@ -284,7 +285,7 @@ async def upsert( embeddings: Optional[ Union[ OneOrMany[Embedding], - OneOrMany[np.ndarray], + OneOrMany[NDArray[EmbeddingDType]], ] ] = None, metadatas: Optional[OneOrMany[Metadata]] = None, diff --git a/chromadb/api/models/Collection.py b/chromadb/api/models/Collection.py index 336f2633d97..dc1aeb0dfef 100644 --- a/chromadb/api/models/Collection.py +++ b/chromadb/api/models/Collection.py @@ -1,5 +1,6 @@ from typing import TYPE_CHECKING, Optional, Union, cast import numpy as np +from numpy.typing import NDArray from chromadb.api.models.CollectionCommon import CollectionCommon from chromadb.api.types import ( @@ -7,6 +8,7 @@ CollectionMetadata, Embedding, Embeddings, + EmbeddingDType, Include, Metadata, Document, @@ -41,10 +43,10 @@ def count(self) -> int: def add( self, ids: OneOrMany[ID], - embeddings: Optional[ # type: ignore[type-arg] + embeddings: Optional[ Union[ OneOrMany[Embedding], - OneOrMany[np.ndarray], + OneOrMany[NDArray[EmbeddingDType]], ] ] = None, metadatas: Optional[OneOrMany[Metadata]] = None, @@ -97,7 +99,7 @@ def get( limit: Optional[int] = None, offset: Optional[int] = None, where_document: Optional[WhereDocument] = None, - include: Include = ["metadatas", "documents"], + include: Include = ["metadatas", "documents"], # type: ignore[list-item] ) -> GetResult: """Get embeddings and their associate data from the data store. If no ids or where filter is provided returns all embeddings up to limit starting at offset. @@ -159,7 +161,7 @@ def query( n_results: int = 10, where: Optional[Where] = None, where_document: Optional[WhereDocument] = None, - include: Include = ["metadatas", "documents", "distances"], + include: Include = ["metadatas", "documents", "distances"], # type: ignore[list-item] ) -> QueryResult: """Get the n_results nearest neighbor embeddings for provided query_embeddings or query_texts. diff --git a/chromadb/api/models/CollectionCommon.py b/chromadb/api/models/CollectionCommon.py index a5395920fc8..bc16e4d07f5 100644 --- a/chromadb/api/models/CollectionCommon.py +++ b/chromadb/api/models/CollectionCommon.py @@ -10,6 +10,7 @@ cast, ) import numpy as np +from numpy.typing import NDArray from uuid import UUID import chromadb.utils.embedding_functions as ef @@ -19,6 +20,7 @@ DataLoader, Embedding, Embeddings, + EmbeddingDType, Embeddable, RecordSet, GetResult, @@ -152,10 +154,10 @@ def get_model(self) -> CollectionModel: def _unpack_embedding_set( self, ids: OneOrMany[ID], - embeddings: Optional[ # type: ignore[type-arg] + embeddings: Optional[ Union[ OneOrMany[Embedding], - OneOrMany[np.ndarray], + OneOrMany[NDArray[EmbeddingDType]], ] ], metadatas: Optional[OneOrMany[Metadata]], @@ -340,7 +342,7 @@ def _validate_and_prepare_query_request( valid_query_embeddings = ( validate_embeddings( self._normalize_embeddings( - maybe_cast_one_to_many_embedding(query_embeddings) # type: ignore[type-arg] + maybe_cast_one_to_many_embedding(query_embeddings) # type: ignore[arg-type] ) ) if query_embeddings is not None @@ -429,7 +431,7 @@ def _process_add_request( embeddings: Optional[ Union[ OneOrMany[Embedding], - OneOrMany[np.ndarray], + OneOrMany[NDArray[EmbeddingDType]], ] ], metadatas: Optional[OneOrMany[Metadata]], diff --git a/chromadb/api/types.py b/chromadb/api/types.py index 34252543ffb..3d5a6a7225b 100644 --- a/chromadb/api/types.py +++ b/chromadb/api/types.py @@ -249,12 +249,7 @@ def __call__(self: EmbeddingFunction[D], input: D) -> Embeddings: converted = maybe_cast_one_to_many_embedding(result) - if converted is None: - raise ValueError( - "Expected embeddings not to be None" - ) - - return validate_embeddings(converted) + return validate_embeddings(converted) # type: ignore[arg-type] setattr(cls, "__call__", __call__)