Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ENH] Generate IDs when not given in add #2699

Open
wants to merge 60 commits into
base: spike/generate_ids_move_validation
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
60 commits
Select commit Hold shift + click to select a range
b5b90b1
lint
spikechroma Aug 22, 2024
ff09b8a
fix types
spikechroma Aug 26, 2024
9c5fd30
update interface
spikechroma Aug 27, 2024
d13a915
fix tests
spikechroma Aug 27, 2024
0e12e97
update validation logic
spikechroma Aug 28, 2024
12658b8
fix lint
spikechroma Aug 28, 2024
7a43952
add back deleted
spikechroma Aug 28, 2024
7a4b565
update validate batch type compatibility
spikechroma Aug 28, 2024
b327557
update fastapi types
spikechroma Aug 28, 2024
025ddd9
remove default to empty array
spikechroma Aug 28, 2024
d70bfb9
fix persist test
spikechroma Aug 28, 2024
40a1f14
make id optional and fix record set generation strategy to allow for …
spikechroma Aug 28, 2024
2c5182e
make changes for optional ids
spikechroma Aug 28, 2024
492d9ec
make changes for optional ids
spikechroma Aug 28, 2024
7fccfdb
make changes for optional ids
spikechroma Aug 28, 2024
eed5bd7
make changes for optional ids
spikechroma Aug 28, 2024
bbdb027
fix type error
spikechroma Aug 28, 2024
8896672
fix type error
spikechroma Aug 28, 2024
b29e4a4
fix tests
spikechroma Aug 28, 2024
06c7d29
reduce entropy in property testing
spikechroma Aug 28, 2024
0b39727
update ts client
spikechroma Aug 28, 2024
4dc3003
lint
spikechroma Aug 28, 2024
3a2f314
create a func for getting record_set len
spikechroma Aug 28, 2024
ac42811
fix tests
spikechroma Aug 29, 2024
abcae71
add comment
spikechroma Aug 29, 2024
e587eec
fix logic error in add_embeddings
spikechroma Aug 29, 2024
a18dc14
lint
spikechroma Aug 29, 2024
9a9870f
update property test
spikechroma Aug 30, 2024
e44d180
fix ndarray error
spikechroma Aug 30, 2024
38ede4e
fix tests
spikechroma Aug 30, 2024
ad57314
update ts client and fix cross version persist test
spikechroma Aug 30, 2024
ef567cc
minor updates
spikechroma Aug 30, 2024
d4a0288
revert changes
spikechroma Sep 2, 2024
d7f07dd
create a new func for ensuring record set consistency
spikechroma Sep 3, 2024
fb4481e
fix broken tests
spikechroma Sep 3, 2024
1b07982
update property tests to handle validation error
spikechroma Sep 9, 2024
9628950
fix test fail
spikechroma Sep 9, 2024
8fd8a30
fix broken tests
spikechroma Sep 10, 2024
3b99b28
revert changes
spikechroma Sep 10, 2024
9c7c36c
fix conflicts
spikechroma Sep 10, 2024
8d5d487
update doc strings, error messages and ignore tags
spikechroma Sep 10, 2024
6961705
fix tests
spikechroma Sep 10, 2024
740d7c1
fix tags and update ids handling logic in cross version test
spikechroma Sep 10, 2024
d9eb963
update count records function
spikechroma Sep 10, 2024
1e346e8
refactor
spikechroma Sep 10, 2024
1a45394
add additional validations
spikechroma Sep 11, 2024
856596a
update logic for getting n items from record set state
spikechroma Sep 11, 2024
2e3e0cb
modify ann accuracy and count
spikechroma Sep 11, 2024
88ebd37
fix error
spikechroma Sep 11, 2024
77ec112
fix tag
spikechroma Sep 11, 2024
edadb92
fix tag
spikechroma Sep 11, 2024
82658fe
fix tag
spikechroma Sep 11, 2024
38fc6b4
turn on can ids be empty for test add medium
spikechroma Sep 11, 2024
a3ba0f8
update persist test to take in state record set
spikechroma Sep 11, 2024
4a1fcc5
add id validations
spikechroma Sep 11, 2024
ba9681b
add tests for upsert and update with none ids
spikechroma Sep 11, 2024
4de0196
lint
spikechroma Sep 11, 2024
2369c54
change func header
spikechroma Sep 11, 2024
4b8608d
change func header
spikechroma Sep 11, 2024
b338b1e
update tests
spikechroma Sep 11, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions chromadb/api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
URIs,
Where,
QueryResult,
AddResult,
GetResult,
WhereDocument,
)
Expand Down Expand Up @@ -115,13 +116,13 @@ def delete_collection(
@abstractmethod
def _add(
self,
ids: IDs,
collection_id: UUID,
embeddings: Embeddings,
ids: Optional[IDs] = None,
metadatas: Optional[Metadatas] = None,
documents: Optional[Documents] = None,
uris: Optional[URIs] = None,
) -> bool:
) -> AddResult:
"""[Internal] Add embeddings to a collection specified by UUID.
If (some) ids already exist, only the new embeddings will be added.

Expand Down
5 changes: 3 additions & 2 deletions chromadb/api/async_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
Where,
QueryResult,
GetResult,
AddResult,
WhereDocument,
)
from chromadb.config import Component, Settings
Expand Down Expand Up @@ -106,13 +107,13 @@ async def delete_collection(
@abstractmethod
async def _add(
self,
ids: IDs,
collection_id: UUID,
embeddings: Embeddings,
ids: Optional[IDs] = None,
metadatas: Optional[Metadatas] = None,
documents: Optional[Documents] = None,
uris: Optional[URIs] = None,
) -> bool:
) -> AddResult:
"""[Internal] Add embeddings to a collection specified by UUID.
If (some) ids already exist, only the new embeddings will be added.

Expand Down
5 changes: 3 additions & 2 deletions chromadb/api/async_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
EmbeddingFunction,
Embeddings,
GetResult,
AddResult,
IDs,
Include,
Loadable,
Expand Down Expand Up @@ -260,13 +261,13 @@ async def delete_collection(
@override
async def _add(
self,
ids: IDs,
collection_id: UUID,
embeddings: Embeddings,
ids: Optional[IDs] = None,
metadatas: Optional[Metadatas] = None,
documents: Optional[Documents] = None,
uris: Optional[URIs] = None,
) -> bool:
) -> AddResult:
spikechroma marked this conversation as resolved.
Show resolved Hide resolved
return await self._server._add(
ids=ids,
collection_id=collection_id,
Expand Down
102 changes: 72 additions & 30 deletions chromadb/api/async_fastapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from uuid import UUID
import urllib.parse
import orjson
from typing import Any, Optional, cast, Tuple, Sequence, Dict
from typing import Any, Optional, cast, Sequence, Dict
import logging
import httpx
from overrides import override
Expand Down Expand Up @@ -30,9 +30,11 @@
Where,
WhereDocument,
GetResult,
AddResult,
QueryResult,
CollectionMetadata,
validate_batch,
validate_batch_size,
RecordSet,
)


Expand Down Expand Up @@ -411,16 +413,10 @@ async def _delete(

return cast(IDs, resp_json)

@trace_method("AsyncFastAPI._submit_batch", OpenTelemetryGranularity.ALL)
async def _submit_batch(
@trace_method("AsyncFastAPI._submit_record_set", OpenTelemetryGranularity.ALL)
async def _submit_record_set(
self,
batch: Tuple[
IDs,
Optional[Embeddings],
Optional[Metadatas],
Optional[Documents],
Optional[URIs],
],
record_set: RecordSet,
url: str,
) -> Any:
"""
Expand All @@ -430,29 +426,53 @@ async def _submit_batch(
"post",
url,
json={
"ids": batch[0],
"embeddings": batch[1],
"metadatas": batch[2],
"documents": batch[3],
"uris": batch[4],
"ids": record_set["ids"],
"embeddings": record_set["embeddings"],
"metadatas": record_set["metadatas"],
"documents": record_set["documents"],
"uris": record_set["uris"],
},
)

@trace_method("AsyncFastAPI._add", OpenTelemetryGranularity.ALL)
@override
async def _add(
self,
ids: IDs,
collection_id: UUID,
embeddings: Embeddings,
ids: Optional[IDs] = None,
metadatas: Optional[Metadatas] = None,
documents: Optional[Documents] = None,
uris: Optional[URIs] = None,
) -> bool:
batch = (ids, embeddings, metadatas, documents, uris)
validate_batch(batch, {"max_batch_size": await self.get_max_batch_size()})
await self._submit_batch(batch, "/collections/" + str(collection_id) + "/add")
return True
) -> AddResult:
record_set: RecordSet = {
"ids": ids,
"embeddings": embeddings,
"metadatas": metadatas,
"documents": documents,
"uris": uris,
"images": None,
}

validate_batch_size(
record_set, {"max_batch_size": await self.get_max_batch_size()}
)

resp_json = await self._make_request(
"post",
"/collections/" + str(collection_id) + "/add",
json={
"ids": record_set["ids"],
"embeddings": record_set["embeddings"],
"metadatas": record_set["metadatas"],
"documents": record_set["documents"],
"uris": record_set["uris"],
},
)

return AddResult(
ids=resp_json["ids"],
)

@trace_method("AsyncFastAPI._update", OpenTelemetryGranularity.ALL)
@override
Expand All @@ -465,11 +485,21 @@ async def _update(
documents: Optional[Documents] = None,
uris: Optional[URIs] = None,
) -> bool:
batch = (ids, embeddings, metadatas, documents, uris)
validate_batch(batch, {"max_batch_size": await self.get_max_batch_size()})
record_set: RecordSet = {
"ids": ids,
"embeddings": embeddings,
"metadatas": metadatas,
"documents": documents,
"uris": uris,
"images": None,
}

validate_batch_size(
record_set, {"max_batch_size": await self.get_max_batch_size()}
)

await self._submit_batch(
batch, "/collections/" + str(collection_id) + "/update"
await self._submit_record_set(
record_set, "/collections/" + str(collection_id) + "/update"
)

return True
Expand All @@ -485,11 +515,23 @@ async def _upsert(
documents: Optional[Documents] = None,
uris: Optional[URIs] = None,
) -> bool:
batch = (ids, embeddings, metadatas, documents, uris)
validate_batch(batch, {"max_batch_size": await self.get_max_batch_size()})
await self._submit_batch(
batch, "/collections/" + str(collection_id) + "/upsert"
record_set: RecordSet = {
"ids": ids,
"embeddings": embeddings,
"metadatas": metadatas,
"documents": documents,
"uris": uris,
"images": None,
}

validate_batch_size(
record_set, {"max_batch_size": await self.get_max_batch_size()}
)

await self._submit_record_set(
record_set, "/collections/" + str(collection_id) + "/upsert"
)

return True

@trace_method("AsyncFastAPI._query", OpenTelemetryGranularity.ALL)
Expand Down
5 changes: 3 additions & 2 deletions chromadb/api/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
Loadable,
Metadatas,
QueryResult,
AddResult,
URIs,
)
from chromadb.config import Settings, System
Expand Down Expand Up @@ -208,13 +209,13 @@ def delete_collection(
@override
def _add(
self,
ids: IDs,
collection_id: UUID,
embeddings: Embeddings,
ids: Optional[IDs] = None,
metadatas: Optional[Metadatas] = None,
documents: Optional[Documents] = None,
uris: Optional[URIs] = None,
) -> bool:
) -> AddResult:
return self._server._add(
ids=ids,
collection_id=collection_id,
Expand Down
Loading
Loading