Skip to content

Commit

Permalink
refactor code to extract out unpacking embeddings set from existing v…
Browse files Browse the repository at this point in the history
…alidation logic
  • Loading branch information
Spike Lu authored and spikechroma committed Aug 27, 2024
1 parent bc12ba7 commit 88fde75
Show file tree
Hide file tree
Showing 4 changed files with 244 additions and 214 deletions.
62 changes: 39 additions & 23 deletions chromadb/api/models/AsyncCollection.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,16 @@
TYPE_CHECKING,
Optional,
Union,
cast,
)
import numpy as np
from numpy.typing import NDArray

from chromadb.api.types import (
URI,
CollectionMetadata,
Embedding,
Embeddings,
EmbeddingDType,
Include,
Metadata,
Document,
Expand All @@ -35,7 +38,7 @@ async def add(
embeddings: Optional[
Union[
OneOrMany[Embedding],
OneOrMany[np.ndarray],
OneOrMany[NDArray[EmbeddingDType]],
]
] = None,
metadatas: Optional[OneOrMany[Metadata]] = None,
Expand Down Expand Up @@ -63,17 +66,23 @@ async def add(
ValueError: If you provide an id that already exists
"""
(
record_set = self._process_add_request(
ids,
embeddings,
metadatas,
documents,
images,
uris,
) = self._validate_and_prepare_embedding_set(
ids, embeddings, metadatas, documents, images, uris
)

await self._client._add(ids, self.id, embeddings, metadatas, documents, uris)
await self._client._add(
record_set["ids"],
self.id,
cast(Embeddings, record_set["embeddings"]),
record_set["metadatas"],
record_set["documents"],
record_set["uris"],
)

async def count(self) -> int:
"""The total number of embeddings added to the database
Expand All @@ -91,7 +100,7 @@ async def get(
limit: Optional[int] = None,
offset: Optional[int] = None,
where_document: Optional[WhereDocument] = None,
include: Include = ["metadatas", "documents"],
include: Include = ["metadatas", "documents"], # type: ignore[list-item]
) -> GetResult:
"""Get embeddings and their associated data from the data store. If no ids or where filter is provided, returns
all embeddings up to limit starting at offset.
Expand Down Expand Up @@ -144,7 +153,7 @@ async def query(
query_embeddings: Optional[
Union[
OneOrMany[Embedding],
OneOrMany[np.ndarray],
OneOrMany[NDArray[EmbeddingDType]],
]
] = None,
query_texts: Optional[OneOrMany[Document]] = None,
Expand All @@ -153,7 +162,7 @@ async def query(
n_results: int = 10,
where: Optional[Where] = None,
where_document: Optional[WhereDocument] = None,
include: Include = ["metadatas", "documents", "distances"],
include: Include = ["metadatas", "documents", "distances"], # type: ignore[list-item]
) -> QueryResult:
"""Get the n_results nearest neighbor embeddings for provided query_embeddings or query_texts.
Expand Down Expand Up @@ -232,7 +241,7 @@ async def update(
embeddings: Optional[
Union[
OneOrMany[Embedding],
OneOrMany[np.ndarray],
OneOrMany[NDArray[EmbeddingDType]],
]
] = None,
metadatas: Optional[OneOrMany[Metadata]] = None,
Expand All @@ -251,25 +260,32 @@ async def update(
Returns:
None
"""
(
record_set = self._process_upsert_or_update_request(
ids,
embeddings,
metadatas,
documents,
images,
uris,
) = self._validate_and_prepare_update_request(
ids, embeddings, metadatas, documents, images, uris
require_embeddings_or_data=False,
)

await self._client._update(self.id, ids, embeddings, metadatas, documents, uris)
await self._client._update(
self.id,
record_set["ids"],
cast(Embeddings, record_set["embeddings"]),
record_set["metadatas"],
record_set["documents"],
record_set["uris"],
)

async def upsert(
self,
ids: OneOrMany[ID],
embeddings: Optional[
Union[
OneOrMany[Embedding],
OneOrMany[np.ndarray],
OneOrMany[NDArray[EmbeddingDType]],
]
] = None,
metadatas: Optional[OneOrMany[Metadata]] = None,
Expand All @@ -288,23 +304,23 @@ async def upsert(
Returns:
None
"""
(
record_set = self._process_upsert_or_update_request(
ids,
embeddings,
metadatas,
documents,
images,
uris,
) = self._validate_and_prepare_upsert_request(
ids, embeddings, metadatas, documents, images, uris
require_embeddings_or_data=True,
)

await self._client._upsert(
collection_id=self.id,
ids=ids,
embeddings=embeddings,
metadatas=metadatas,
documents=documents,
uris=uris,
ids=record_set["ids"],
embeddings=cast(Embeddings, record_set["embeddings"]),
metadatas=record_set["metadatas"],
documents=record_set["documents"],
uris=record_set["uris"],
)

async def delete(
Expand Down
60 changes: 38 additions & 22 deletions chromadb/api/models/Collection.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
from typing import TYPE_CHECKING, Optional, Union
from typing import TYPE_CHECKING, Optional, Union, cast
import numpy as np
from numpy.typing import NDArray

from chromadb.api.models.CollectionCommon import CollectionCommon
from chromadb.api.types import (
URI,
CollectionMetadata,
Embedding,
Embeddings,
EmbeddingDType,
Include,
Metadata,
Document,
Expand Down Expand Up @@ -40,10 +43,10 @@ def count(self) -> int:
def add(
self,
ids: OneOrMany[ID],
embeddings: Optional[ # type: ignore[type-arg]
embeddings: Optional[
Union[
OneOrMany[Embedding],
OneOrMany[np.ndarray],
OneOrMany[NDArray[EmbeddingDType]],
]
] = None,
metadatas: Optional[OneOrMany[Metadata]] = None,
Expand Down Expand Up @@ -71,17 +74,23 @@ def add(
ValueError: If you provide an id that already exists
"""
(
record_set = self._process_add_request(
ids,
embeddings,
metadatas,
documents,
images,
uris,
) = self._validate_and_prepare_embedding_set(
ids, embeddings, metadatas, documents, images, uris
)

self._client._add(ids, self.id, embeddings, metadatas, documents, uris)
self._client._add(
record_set["ids"],
self.id,
cast(Embeddings, record_set["embeddings"]),
record_set["metadatas"],
record_set["documents"],
record_set["uris"],
)

def get(
self,
Expand All @@ -90,7 +99,7 @@ def get(
limit: Optional[int] = None,
offset: Optional[int] = None,
where_document: Optional[WhereDocument] = None,
include: Include = ["metadatas", "documents"],
include: Include = ["metadatas", "documents"], # type: ignore[list-item]
) -> GetResult:
"""Get embeddings and their associated data from the data store. If no ids or where filter is provided, returns
all embeddings up to limit starting at offset.
Expand Down Expand Up @@ -152,7 +161,7 @@ def query(
n_results: int = 10,
where: Optional[Where] = None,
where_document: Optional[WhereDocument] = None,
include: Include = ["metadatas", "documents", "distances"],
include: Include = ["metadatas", "documents", "distances"], # type: ignore[list-item]
) -> QueryResult:
"""Get the n_results nearest neighbor embeddings for provided query_embeddings or query_texts.
Expand Down Expand Up @@ -250,17 +259,24 @@ def update(
Returns:
None
"""
(
record_set = self._process_upsert_or_update_request(
ids,
embeddings,
metadatas,
documents,
images,
uris,
) = self._validate_and_prepare_update_request(
ids, embeddings, metadatas, documents, images, uris
require_embeddings_or_data=False,
)

self._client._update(self.id, ids, embeddings, metadatas, documents, uris)
self._client._update(
self.id,
record_set["ids"],
cast(Embeddings, record_set["embeddings"]),
record_set["metadatas"],
record_set["documents"],
record_set["uris"],
)

def upsert(
self,
Expand All @@ -287,23 +303,23 @@ def upsert(
Returns:
None
"""
(
record_set = self._process_upsert_or_update_request(
ids,
embeddings,
metadatas,
documents,
images,
uris,
) = self._validate_and_prepare_upsert_request(
ids, embeddings, metadatas, documents, images, uris
require_embeddings_or_data=True,
)

self._client._upsert(
collection_id=self.id,
ids=ids,
embeddings=embeddings,
metadatas=metadatas,
documents=documents,
uris=uris,
self.id,
record_set["ids"],
cast(Embeddings, record_set["embeddings"]),
record_set["metadatas"],
record_set["documents"],
record_set["uris"],
)

def delete(
Expand Down
Loading

0 comments on commit 88fde75

Please sign in to comment.