-
Notifications
You must be signed in to change notification settings - Fork 1.2k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[PERF] Convert embeddings representation to numpy #2803
base: main
Are you sure you want to change the base?
Conversation
Reviewer ChecklistPlease leverage this checklist to ensure your code review is thorough before approving Testing, Bugs, Errors, Logs, Documentation
System Compatibility
Quality
|
Please tag your PR title with one of: [ENH | BUG | DOC | TST | BLD | PERF | TYP | CLN | CHORE]. See https://docs.trychroma.com/contributing#contributing-code-and-ideas |
Please tag your PR title with one of: [ENH | BUG | DOC | TST | BLD | PERF | TYP | CLN | CHORE]. See https://docs.trychroma.com/contributing#contributing-code-and-ideas |
1 similar comment
Please tag your PR title with one of: [ENH | BUG | DOC | TST | BLD | PERF | TYP | CLN | CHORE]. See https://docs.trychroma.com/contributing#contributing-code-and-ideas |
not sure what this referencing; the profiling methodology included in a Detailswould be really helpful :) |
elif encoding == ScalarEncoding.INT32: | ||
return array.array("i", vector).tolist() | ||
return np.frombuffer(vector, dtype=np.float32) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
maybe we should update chromadb/test/property/test_cross_version_persist.py
to leave some entries in the WAL between version changes to assert that these serialization formats are compatible
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I like the idea but am having trouble getting the test to pass without purging the log between runs—would spot checking manually that we can serialize as an list and deserialize as a numpy array work instead?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
hmm yeah probably
also happy to pair to try to get the test working
PyEmbedding = PyVector | ||
Embedding = Vector |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
may be out of scope but the proliferation of types and indirection here is kinda confusing, wondering if just having PyVector
and NpVector
instead would work
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think out of scope for this PR—it would require changing quite a bit more code across the board and introduce more risk. I also think there's a good reason to keep the idea of an Embedding
separate from that of a Vector
. While the former is always the latter, there are usages of Vector
across the code base where it is not an Embedding
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah agree this is weird but I guess the distinction being made is "vector" is a array of numbers vs "embedding" is a domain specific input to chroma that may change. Which makes sense to separate.
21630d8
to
36a3bfc
Compare
chromadb/api/async_fastapi.py
Outdated
@@ -449,7 +449,7 @@ async def _add( | |||
documents: Optional[Documents] = None, | |||
uris: Optional[URIs] = None, | |||
) -> bool: | |||
batch = (ids, embeddings, metadatas, documents, uris) | |||
batch = (ids, [embedding.tolist() for embedding in embeddings] if embeddings is not None else None, metadatas, documents, uris) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
When is embeddings ever None here? I don't understand the defensive check because the input type is not None?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fixed
chromadb/api/async_fastapi.py
Outdated
@@ -465,7 +465,7 @@ async def _update( | |||
documents: Optional[Documents] = None, | |||
uris: Optional[URIs] = None, | |||
) -> bool: | |||
batch = (ids, embeddings, metadatas, documents, uris) | |||
batch = (ids, [embedding.tolist() for embedding in embeddings] if embeddings is not None else None, metadatas, documents, uris) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Same comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
embeddings
can be None
here—both given the method definition and if you trace through the codepath. In _validate_and_prepare_update_request
, we call _validate_embedding_set
with require_embeddings_or_data
set to False
, which means we can end up sending self._client._update
a None
for embeddings
.
chromadb/api/fastapi.py
Outdated
@@ -418,7 +418,7 @@ def _add( | |||
Adds a batch of embeddings to the database | |||
- pass in column oriented data lists | |||
""" | |||
batch = (ids, embeddings, metadatas, documents, uris) | |||
batch = (ids, [embedding.tolist() for embedding in embeddings] if embeddings is not None else None, metadatas, documents, uris) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: maybe move the [np] conversion out into a convenience method
@@ -136,14 +137,14 @@ def peek(self, limit: int = 10) -> GetResult: | |||
Returns: | |||
GetResult: A GetResult object containing the results. | |||
""" | |||
return self._client._peek(self.id, limit) | |||
return self._transform_peek_response(self._client._peek(self.id, limit)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why does async collection not need this change?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good catch!
@@ -327,6 +334,9 @@ def _transform_get_response( | |||
): | |||
response["data"] = self._data_loader(response["uris"]) | |||
|
|||
if ("embeddings" in include and response["embeddings"] is not None): | |||
response["embeddings"] = np.array(response["embeddings"]) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
we should note: this is a breaking API change and needs documentation
if isinstance(embeddings, np.ndarray): | ||
return embeddings.tolist() # type: ignore | ||
return embeddings # type: ignore | ||
return cast(Embeddings, [np.array(embedding) for embedding in embeddings]) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
question: do we validate the shape of PyEmbedding anywhere?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do you mean do we validate that the embeddings the user provide have the same dimensionality? If so, we do in chromadb/api/segment.py
on line 828 in _validate_dimension
.
Description of changes
Summarize the changes made by this PR.
Improvements & Bug fixes
Notes
The conversion of a numpy array to a python list via
.tolist()
consumes an excess amount of memory, and so the focus of this PR is to reduce the number of these operations by representing embeddings as numpy arrays across the board.The format for
Embeddings
changes fromList[List[Union[float, int]]]
toList[NDArray[Union[int32, float32]]]
.Using a 2D numpy array instead of a Python list of 1D numpy arrays would be slightly more performant. However, it would require deep, fundamental, and widespread changes to not only our code but also testing frameworks. I have a commit that starts to go down this path (if you're interested in seeing), and it's not pretty. The changes introduce an incredible amount of risk across the board for little to no gain on the vector we are trying to improve and thus is not worth it.
FastAPI does not support serialization of numpy arrays. Thus, upon encoding/decoding embeddings through our middleware, we have no option but to convert back to python lists. This is acceptable, however, because these operations only happen within Distributed.
This PR introduces a breaking change to the API in that embeddings will now be returned as 2D numpy arrays. It makes sense to me both from a memory usage issue perspective (not calling
.tolist()
upon returning a response) and also generally provides better performance within users.Embeddings will now be internally standardized to
float32
/int32
upon serialization (for RPC and byte encoding), which will could cause a little bit of noise within distance functions depending on what the user provides. We should communicate this change in the next version release.Because of the aforementioned FastAPI serialization issue and the fact that structures like
QueryResult
andGetResult
are threaded all the way across the stack through our middleware, there is a new risk introduced. We could fork Local/Distributed Chroma and create new structures for the former that use numpy arrays. However, this is unideal to me. I decided instead to let these structures support both numpy arrays and pythonic lists for embeddings. When were are on either side of the middleware, we must convert between numpy arrays and lists. This isn’t an issue in of itself, but makes it such that we always have to remember to perform the conversion when using the structures, lest we risk returning embeddings as pythonic lists within Distributed Chroma. New development of clients seems unlikely for now, so this feels like an acceptable risk.Using the same profiling method, over 10 runs the average memory used was originally 362 MB, going to 264 MB with changes (representing a 28% decrease in memory usage).
Performance details
I ran the following code 10 times each on main and on my branch and averaged the results to test the performance. This was the same code as used in the original [issue](https://github.com//issues/2665)Test plan
How are these changes tested?
pytest
for python,yarn test
for js,cargo test
for rust