Skip to content

Commit

Permalink
fix types
Browse files Browse the repository at this point in the history
  • Loading branch information
spikechroma committed Aug 26, 2024
1 parent f4fd6b4 commit d7d276a
Show file tree
Hide file tree
Showing 4 changed files with 21 additions and 21 deletions.
15 changes: 8 additions & 7 deletions chromadb/api/models/AsyncCollection.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,14 @@
Union,
cast,
)
import numpy as np
from numpy.typing import NDArray

from chromadb.api.types import (
URI,
CollectionMetadata,
Embedding,
Embeddings,
EmbeddingDType,
Include,
Metadata,
Document,
Expand All @@ -37,7 +38,7 @@ async def add(
embeddings: Optional[
Union[
OneOrMany[Embedding],
OneOrMany[np.ndarray],
OneOrMany[NDArray[EmbeddingDType]],
]
] = None,
metadatas: Optional[OneOrMany[Metadata]] = None,
Expand Down Expand Up @@ -99,7 +100,7 @@ async def get(
limit: Optional[int] = None,
offset: Optional[int] = None,
where_document: Optional[WhereDocument] = None,
include: Include = ["metadatas", "documents"],
include: Include = ["metadatas", "documents"], # type: ignore[list-item]
) -> GetResult:
"""Get embeddings and their associate data from the data store. If no ids or where filter is provided returns
all embeddings up to limit starting at offset.
Expand Down Expand Up @@ -152,7 +153,7 @@ async def query(
query_embeddings: Optional[
Union[
OneOrMany[Embedding],
OneOrMany[np.ndarray],
OneOrMany[NDArray[EmbeddingDType]],
]
] = None,
query_texts: Optional[OneOrMany[Document]] = None,
Expand All @@ -161,7 +162,7 @@ async def query(
n_results: int = 10,
where: Optional[Where] = None,
where_document: Optional[WhereDocument] = None,
include: Include = ["metadatas", "documents", "distances"],
include: Include = ["metadatas", "documents", "distances"], # type: ignore[list-item]
) -> QueryResult:
"""Get the n_results nearest neighbor embeddings for provided query_embeddings or query_texts.
Expand Down Expand Up @@ -240,7 +241,7 @@ async def update(
embeddings: Optional[
Union[
OneOrMany[Embedding],
OneOrMany[np.ndarray],
OneOrMany[NDArray[EmbeddingDType]],
]
] = None,
metadatas: Optional[OneOrMany[Metadata]] = None,
Expand Down Expand Up @@ -284,7 +285,7 @@ async def upsert(
embeddings: Optional[
Union[
OneOrMany[Embedding],
OneOrMany[np.ndarray],
OneOrMany[NDArray[EmbeddingDType]],
]
] = None,
metadatas: Optional[OneOrMany[Metadata]] = None,
Expand Down
10 changes: 6 additions & 4 deletions chromadb/api/models/Collection.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
from typing import TYPE_CHECKING, Optional, Union, cast
import numpy as np
from numpy.typing import NDArray

from chromadb.api.models.CollectionCommon import CollectionCommon
from chromadb.api.types import (
URI,
CollectionMetadata,
Embedding,
Embeddings,
EmbeddingDType,
Include,
Metadata,
Document,
Expand Down Expand Up @@ -41,10 +43,10 @@ def count(self) -> int:
def add(
self,
ids: OneOrMany[ID],
embeddings: Optional[ # type: ignore[type-arg]
embeddings: Optional[
Union[
OneOrMany[Embedding],
OneOrMany[np.ndarray],
OneOrMany[NDArray[EmbeddingDType]],
]
] = None,
metadatas: Optional[OneOrMany[Metadata]] = None,
Expand Down Expand Up @@ -97,7 +99,7 @@ def get(
limit: Optional[int] = None,
offset: Optional[int] = None,
where_document: Optional[WhereDocument] = None,
include: Include = ["metadatas", "documents"],
include: Include = ["metadatas", "documents"], # type: ignore[list-item]
) -> GetResult:
"""Get embeddings and their associate data from the data store. If no ids or where filter is provided returns
all embeddings up to limit starting at offset.
Expand Down Expand Up @@ -159,7 +161,7 @@ def query(
n_results: int = 10,
where: Optional[Where] = None,
where_document: Optional[WhereDocument] = None,
include: Include = ["metadatas", "documents", "distances"],
include: Include = ["metadatas", "documents", "distances"], # type: ignore[list-item]
) -> QueryResult:
"""Get the n_results nearest neighbor embeddings for provided query_embeddings or query_texts.
Expand Down
10 changes: 6 additions & 4 deletions chromadb/api/models/CollectionCommon.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
cast,
)
import numpy as np
from numpy.typing import NDArray
from uuid import UUID

import chromadb.utils.embedding_functions as ef
Expand All @@ -19,6 +20,7 @@
DataLoader,
Embedding,
Embeddings,
EmbeddingDType,
Embeddable,
RecordSet,
GetResult,
Expand Down Expand Up @@ -152,10 +154,10 @@ def get_model(self) -> CollectionModel:
def _unpack_embedding_set(
self,
ids: OneOrMany[ID],
embeddings: Optional[ # type: ignore[type-arg]
embeddings: Optional[
Union[
OneOrMany[Embedding],
OneOrMany[np.ndarray],
OneOrMany[NDArray[EmbeddingDType]],
]
],
metadatas: Optional[OneOrMany[Metadata]],
Expand Down Expand Up @@ -340,7 +342,7 @@ def _validate_and_prepare_query_request(
valid_query_embeddings = (
validate_embeddings(
self._normalize_embeddings(
maybe_cast_one_to_many_embedding(query_embeddings) # type: ignore[type-arg]
maybe_cast_one_to_many_embedding(query_embeddings) # type: ignore[arg-type]
)
)
if query_embeddings is not None
Expand Down Expand Up @@ -429,7 +431,7 @@ def _process_add_request(
embeddings: Optional[
Union[
OneOrMany[Embedding],
OneOrMany[np.ndarray],
OneOrMany[NDArray[EmbeddingDType]],
]
],
metadatas: Optional[OneOrMany[Metadata]],
Expand Down
7 changes: 1 addition & 6 deletions chromadb/api/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,12 +249,7 @@ def __call__(self: EmbeddingFunction[D], input: D) -> Embeddings:

converted = maybe_cast_one_to_many_embedding(result)

if converted is None:
raise ValueError(
"Expected embeddings not to be None"
)

return validate_embeddings(converted)
return validate_embeddings(converted) # type: ignore[arg-type]

setattr(cls, "__call__", __call__)

Expand Down

0 comments on commit d7d276a

Please sign in to comment.