Skip to content

Commit

Permalink
[CLN] Support numpy >=2.0 (#2811)
Browse files Browse the repository at this point in the history
## Description of changes

*Summarize the changes made by this PR.*
 - Improvements & Bug fixes
- Supports numpy 2.0 now that onnxruntime supports it by removing our
limit of <2.0
- Change usage of `float_` to `float64`. These are supported in both <
and >= 2.0. However `float_` is only supported by < 2.0. In < 2.0 the _
types are shorthand for the 64bit wide datatypes. I verified this. This
makes the changes cleaner than #2776.
	 
<img width="300" alt="Screenshot 2024-09-17 at 10 15 21 AM"
src="https://github.com/user-attachments/assets/8e290733-419c-4d3c-9d71-dc76d01618d7">

	 
 - New functionality
	 - None

## Test plan
*How are these changes tested?*
Existing tests
- [x] Tests pass locally with `pytest` for python, `yarn test` for js,
`cargo test` for rust

## Documentation Changes
None
  • Loading branch information
HammadB authored Sep 17, 2024
1 parent efea481 commit 9ab0196
Show file tree
Hide file tree
Showing 5 changed files with 10 additions and 10 deletions.
2 changes: 1 addition & 1 deletion chromadb/api/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ def maybe_cast_one_to_many_document(target: OneOrMany[Document]) -> Documents:


# Images
ImageDType = Union[np.uint, np.int_, np.float_] # type: ignore[name-defined]
ImageDType = Union[np.uint, np.int64, np.float64]
Image = NDArray[ImageDType]
Images = List[Image]

Expand Down
12 changes: 6 additions & 6 deletions chromadb/test/ef/test_multimodal_ef.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
# then hashes them to a fixed dimension.
class hashing_multimodal_ef(EmbeddingFunction[Embeddable]):
def __init__(self) -> None:
self._hef = hashing_embedding_function(dim=10, dtype=np.float_)
self._hef = hashing_embedding_function(dim=10, dtype=np.float64)

def __call__(self, input: Embeddable) -> Embeddings:
to_texts = [str(i) for i in input]
Expand All @@ -29,7 +29,7 @@ def __call__(self, input: Embeddable) -> Embeddings:


def random_image() -> Image:
return np.random.randint(0, 255, size=(10, 10, 3), dtype=np.int32)
return np.random.randint(0, 255, size=(10, 10, 3), dtype=np.int64)


def random_document() -> Document:
Expand Down Expand Up @@ -82,7 +82,7 @@ def test_multimodal(

# get() should return all the documents and images
# ids corresponding to images should not have documents
get_result = multimodal_collection.get(include=["documents"])
get_result = multimodal_collection.get(include=["documents"]) # type: ignore[list-item]
assert len(get_result["ids"]) == len(document_ids) + len(image_ids)
for i, id in enumerate(get_result["ids"]):
assert id in document_ids or id in image_ids
Expand Down Expand Up @@ -124,14 +124,14 @@ def test_multimodal(

# Query with images
query_result = multimodal_collection.query(
query_images=[query_image], n_results=n_query_results, include=["documents"]
query_images=[query_image], n_results=n_query_results, include=["documents"] # type: ignore[list-item]
)

assert query_result["ids"][0] == nearest_image_neighbor_ids

# Query with documents
query_result = multimodal_collection.query(
query_texts=[query_document], n_results=n_query_results, include=["documents"]
query_texts=[query_document], n_results=n_query_results, include=["documents"] # type: ignore[list-item]
)

assert query_result["ids"][0] == nearest_document_neighbor_ids
Expand All @@ -152,6 +152,6 @@ def test_multimodal_update_with_image(

multimodal_collection.update(ids=id, images=image)

get_result = multimodal_collection.get(ids=id, include=["documents"])
get_result = multimodal_collection.get(ids=id, include=["documents"]) # type: ignore[list-item]
assert get_result["documents"] is not None
assert get_result["documents"][0] is None
2 changes: 1 addition & 1 deletion clients/python/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
httpx>=0.27.0
numpy >= 1.22.5, < 2.0.0
numpy >= 1.22.5
opentelemetry-api>=1.2.0
opentelemetry-exporter-otlp-proto-grpc>=1.2.0
opentelemetry-sdk>=1.2.0
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ dependencies = [
'chroma-hnswlib==0.7.6',
'fastapi >= 0.95.2',
'uvicorn[standard] >= 0.18.3',
'numpy >= 1.22.5, < 2.0.0',
'numpy >= 1.22.5',
'posthog >= 2.4.0',
'typing_extensions >= 4.5.0',
'onnxruntime >= 1.14.1',
Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ httpx>=0.27.0
importlib-resources
kubernetes>=28.1.0
mmh3>=4.0.1
numpy>=1.22.5, <2.0.0
numpy>=1.22.5
onnxruntime>=1.14.1
opentelemetry-api>=1.2.0
opentelemetry-exporter-otlp-proto-grpc>=1.24.0
Expand Down

0 comments on commit 9ab0196

Please sign in to comment.