Skip to content

Commit

Permalink
[ENH] Generate IDs when not given in upsert and add (#2693)
Browse files Browse the repository at this point in the history
## Description of changes

*Summarize the changes made by this PR.*
 - Improvements & Bug fixes
- Refactored functions to adhere to the single-responsibility principle
(SRP), reducing the number of abstraction levels within each function.
 - New functionality
- When a user calls add or upsert on a collection, they are no longer
required to pass in an array of IDs. IDs are generated automatically
when none are given.

## Test plan
*How are these changes tested?*

- [ ] Tests pass locally with `pytest` for python, `yarn test` for js,
`cargo test` for rust

## Documentation Changes
*Are all docstrings for user-facing APIs updated if required? Do we need
to make documentation changes in the [docs
repository](https://github.com/chroma-core/docs)?*
  • Loading branch information
spikechroma authored Aug 21, 2024
1 parent 781622b commit 4a5b473
Show file tree
Hide file tree
Showing 7 changed files with 117 additions and 106 deletions.
17 changes: 11 additions & 6 deletions chromadb/api/models/AsyncCollection.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

from chromadb.api.types import (
URI,
AddResult,
CollectionMetadata,
Embedding,
Include,
Expand All @@ -33,7 +34,7 @@
class AsyncCollection(CollectionCommon["AsyncServerAPI"]):
async def add(
self,
ids: OneOrMany[ID],
ids: Optional[OneOrMany[ID]] = None,
embeddings: Optional[
Union[
OneOrMany[Embedding],
Expand All @@ -44,7 +45,7 @@ async def add(
documents: Optional[OneOrMany[Document]] = None,
images: Optional[OneOrMany[Image]] = None,
uris: Optional[OneOrMany[URI]] = None,
) -> None:
) -> AddResult:
"""Add embeddings to the data store.
Args:
ids: The ids of the embeddings you wish to add
Expand Down Expand Up @@ -75,14 +76,18 @@ async def add(
)

await self._client._add(
embedding_set["ids"],
cast(IDs, embedding_set["ids"]),
self.id,
cast(Embeddings, embedding_set["embeddings"]),
embedding_set["metadatas"],
embedding_set["documents"],
embedding_set["uris"],
)

return {
"ids": embedding_set["ids"],
}

async def count(self) -> int:
"""The total number of embeddings added to the database
Expand Down Expand Up @@ -259,7 +264,7 @@ async def update(
Returns:
None
"""
embedding_set = self._process_update_request(
embedding_set = self._process_upsert_or_update_request(
ids, embeddings, metadatas, documents, images, uris
)

Expand Down Expand Up @@ -297,13 +302,13 @@ async def upsert(
Returns:
None
"""
embedding_set = self._process_upsert_request(
embedding_set = self._process_upsert_or_update_request(
ids, embeddings, metadatas, documents, images, uris
)

await self._client._upsert(
collection_id=self.id,
ids=embedding_set["ids"],
ids=cast(IDs, embedding_set["ids"]),
embeddings=cast(Embeddings, embedding_set["embeddings"]),
metadatas=embedding_set["metadatas"],
documents=embedding_set["documents"],
Expand Down
19 changes: 12 additions & 7 deletions chromadb/api/models/Collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
Include,
Metadata,
Document,
AddResult,
Image,
Where,
IDs,
Expand Down Expand Up @@ -40,7 +41,7 @@ def count(self) -> int:

def add(
self,
ids: OneOrMany[ID],
ids: Optional[OneOrMany[ID]] = None,
embeddings: Optional[ # type: ignore[type-arg]
Union[
OneOrMany[Embedding],
Expand All @@ -51,7 +52,7 @@ def add(
documents: Optional[OneOrMany[Document]] = None,
images: Optional[OneOrMany[Image]] = None,
uris: Optional[OneOrMany[URI]] = None,
) -> None:
) -> AddResult:
"""Add embeddings to the data store.
Args:
ids: The ids of the embeddings you wish to add
Expand Down Expand Up @@ -82,14 +83,18 @@ def add(
)

self._client._add(
embedding_set["ids"],
cast(IDs, embedding_set["ids"]),
self.id,
cast(Embeddings, embedding_set["embeddings"]),
embedding_set["metadatas"],
embedding_set["documents"],
embedding_set["uris"],
)

return {
"ids": embedding_set["ids"],
}

def get(
self,
ids: Optional[OneOrMany[ID]] = None,
Expand Down Expand Up @@ -257,13 +262,13 @@ def update(
Returns:
None
"""
embedding_set = self._process_update_request(
embedding_set = self._process_upsert_or_update_request(
ids, embeddings, metadatas, documents, images, uris
)

self._client._update(
self.id,
embedding_set["ids"],
cast(IDs, embedding_set["ids"]),
cast(Embeddings, embedding_set["embeddings"]),
embedding_set["metadatas"],
embedding_set["documents"],
Expand Down Expand Up @@ -295,13 +300,13 @@ def upsert(
Returns:
None
"""
embedding_set = self._process_upsert_request(
embedding_set = self._process_upsert_or_update_request(
ids, embeddings, metadatas, documents, images, uris
)

self._client._upsert(
collection_id=self.id,
ids=embedding_set["ids"],
ids=cast(IDs, embedding_set["ids"]),
embeddings=cast(Embeddings, embedding_set["embeddings"]),
metadatas=embedding_set["metadatas"],
documents=embedding_set["documents"],
Expand Down
126 changes: 47 additions & 79 deletions chromadb/api/models/CollectionCommon.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
cast,
)
import numpy as np
from uuid import UUID
from uuid import UUID, uuid4

import chromadb.utils.embedding_functions as ef
from chromadb.api.types import (
Expand Down Expand Up @@ -151,7 +151,7 @@ def get_model(self) -> CollectionModel:

def _unpack_embedding_set(
self,
ids: OneOrMany[ID],
ids: Optional[OneOrMany[ID]],
embeddings: Optional[
Union[
OneOrMany[Embedding],
Expand Down Expand Up @@ -181,7 +181,7 @@ def _unpack_embedding_set(

def _validate_embedding_set(
self,
ids: IDs,
ids: Optional[IDs],
embeddings: Optional[Embeddings],
metadatas: Optional[Metadatas],
documents: Optional[Documents],
Expand All @@ -197,10 +197,6 @@ def _validate_embedding_set(
validate_metadatas(metadatas) if metadatas is not None else None
)

valid_documents = maybe_cast_one_to_many_document(documents)
valid_images = maybe_cast_one_to_many_image(images)
valid_uris = maybe_cast_one_to_many_uri(uris)

# Check that one of embeddings or ducuments or images is provided
if require_embeddings_or_data:
if (
Expand All @@ -214,7 +210,7 @@ def _validate_embedding_set(
)

# Only one of documents or images can be provided
if valid_documents is not None and valid_images is not None:
if documents is not None and images is not None:
raise ValueError("You can only provide documents or images, not both.")

# Check that, if they're provided, the lengths of the arrays match the length of ids
Expand All @@ -226,17 +222,17 @@ def _validate_embedding_set(
raise ValueError(
f"Number of metadatas {len(valid_metadatas)} must match number of ids {len(valid_ids)}"
)
if valid_documents is not None and len(valid_documents) != len(valid_ids):
if documents is not None and len(documents) != len(valid_ids):
raise ValueError(
f"Number of documents {len(valid_documents)} must match number of ids {len(valid_ids)}"
f"Number of documents {len(documents)} must match number of ids {len(valid_ids)}"
)
if valid_images is not None and len(valid_images) != len(valid_ids):
if images is not None and len(images) != len(valid_ids):
raise ValueError(
f"Number of images {len(valid_images)} must match number of ids {len(valid_ids)}"
f"Number of images {len(images)} must match number of ids {len(valid_ids)}"
)
if valid_uris is not None and len(valid_uris) != len(valid_ids):
if uris is not None and len(uris) != len(valid_ids):
raise ValueError(
f"Number of uris {len(valid_uris)} must match number of ids {len(valid_ids)}"
f"Number of uris {len(uris)} must match number of ids {len(valid_ids)}"
)

def _prepare_embeddings(
Expand Down Expand Up @@ -426,9 +422,36 @@ def _update_model_after_modify_success(
if metadata:
self._model["metadata"] = metadata

@staticmethod
def _generate_ids_when_not_present(
ids: Optional[IDs],
documents: Optional[Documents],
uris: Optional[URIs],
images: Optional[Images],
embeddings: Optional[Embeddings],
) -> IDs:
if ids is not None and len(ids) > 0:
return ids

n = 0
if documents is not None:
n = len(documents)
elif uris is not None:
n = len(uris)
elif images is not None:
n = len(images)
elif embeddings is not None:
n = len(embeddings)

generated_ids = []
for _ in range(n):
generated_ids.append(str(uuid4()))

return generated_ids

def _process_add_request(
self,
ids: OneOrMany[ID],
ids: Optional[OneOrMany[ID]],
embeddings: Optional[
Union[
OneOrMany[Embedding],
Expand All @@ -455,72 +478,16 @@ def _process_add_request(
else None
)

self._validate_embedding_set(
generated_ids = self._generate_ids_when_not_present(
unpacked_embedding_set["ids"],
normalized_embeddings,
unpacked_embedding_set["metadatas"],
unpacked_embedding_set["documents"],
unpacked_embedding_set["images"],
unpacked_embedding_set["uris"],
require_embeddings_or_data=False,
)

prepared_embeddings = self._prepare_embeddings(
normalized_embeddings,
unpacked_embedding_set["documents"],
unpacked_embedding_set["images"],
unpacked_embedding_set["uris"],
)

return {
"ids": unpacked_embedding_set["ids"],
"embeddings": prepared_embeddings,
"metadatas": unpacked_embedding_set["metadatas"],
"documents": unpacked_embedding_set["documents"],
"images": unpacked_embedding_set["images"],
"uris": unpacked_embedding_set["uris"],
}

def _prepare_update_request(
self,
embeddings: Optional[Embeddings],
documents: Optional[Documents],
images: Optional[Images],
) -> Embeddings:
if embeddings is None:
if documents is not None:
embeddings = self._embed(input=documents)
elif images is not None:
embeddings = self._embed(input=images)

return cast(Embeddings, embeddings)

def _process_update_request(
self,
ids: OneOrMany[ID],
embeddings: Optional[ # type: ignore[type-arg]
Union[
OneOrMany[Embedding],
OneOrMany[np.ndarray],
]
],
metadatas: Optional[OneOrMany[Metadata]],
documents: Optional[OneOrMany[Document]],
images: Optional[OneOrMany[Image]],
uris: Optional[OneOrMany[URI]],
) -> EmbeddingSet:
unpacked_embedding_set = self._unpack_embedding_set(
ids, embeddings, metadatas, documents, images, uris
)

normalized_embeddings = (
self._normalize_embeddings(unpacked_embedding_set["embeddings"])
if unpacked_embedding_set["embeddings"] is not None
else None
normalized_embeddings,
)

self._validate_embedding_set(
unpacked_embedding_set["ids"],
generated_ids,
normalized_embeddings,
unpacked_embedding_set["metadatas"],
unpacked_embedding_set["documents"],
Expand All @@ -529,22 +496,23 @@ def _process_update_request(
require_embeddings_or_data=False,
)

prepared_embeddings = self._prepare_update_request(
prepared_embeddings = self._prepare_embeddings(
normalized_embeddings,
unpacked_embedding_set["documents"],
unpacked_embedding_set["images"],
unpacked_embedding_set["uris"],
)

return {
"ids": unpacked_embedding_set["ids"],
"ids": generated_ids,
"embeddings": prepared_embeddings,
"metadatas": unpacked_embedding_set["metadatas"],
"documents": unpacked_embedding_set["documents"],
"images": unpacked_embedding_set["images"],
"uris": unpacked_embedding_set["uris"],
}

def _prepare_upsert_request(
def _prepare_upsert_or_update_request(
self,
embeddings: Optional[Embeddings],
documents: Optional[Documents],
Expand All @@ -558,7 +526,7 @@ def _prepare_upsert_request(

return cast(Embeddings, embeddings)

def _process_upsert_request(
def _process_upsert_or_update_request(
self,
ids: OneOrMany[ID],
embeddings: Optional[ # type: ignore[type-arg]
Expand Down Expand Up @@ -592,7 +560,7 @@ def _process_upsert_request(
require_embeddings_or_data=False,
)

prepared_embeddings = self._prepare_upsert_request(
prepared_embeddings = self._prepare_upsert_or_update_request(
normalized_embeddings,
unpacked_embedding_set["documents"],
unpacked_embedding_set["images"],
Expand Down
Loading

0 comments on commit 4a5b473

Please sign in to comment.