Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Use numpy.ndarray and numpy dtypes instead of memref types in Python API #1786

Draft
wants to merge 21 commits into
base: main
Choose a base branch
from

Conversation

hunhoffe
Copy link
Collaborator

@hunhoffe hunhoffe commented Sep 23, 2024

Summary

This PR optionally allows a Python user to use numpy ndarray and numpy dtype types instead of memref and MLIR types in their designs.

Specifically, changes in this PR include:

  • Porting all programming_examples to use numpy types (except those examples using bfloat16 as numpy doesn't support that type directly - those examples use the existing memref/MLIR types without modification)
  • A new example, programming_examples/basic/passthrough_pykernel, which uses a mixture of types (to make sure a mix works ok) and is also the only example (that I know of) that shows how to create/call a function defined in Python

If these changes are accepted, I'll do the work to change the programming_guide to use the numpy types, as well as port at least half of the python tests in test to similarly use numpy types. I can include this in a follow-on PR or add it to this PR, depending on advice from reviewers.

Example

For example, to define a memref type in the past (such as for use in initializing an object FIFO) as well as external function argument types, you would write something like:

        # define types
        memRef_ty = T.memref(lineWidthInBytes, T.ui8())

        passThroughLine = external_func(
            "passThroughLine", inputs=[memRef_ty, memRef_ty, T.i32()]
        )

Now, you can instead write a type definition like so:

        line_ty = np.ndarray[(lineWidthInBytes,), np.dtype[np.int8]]

        passThroughLine = external_func(
            "passThroughLine", inputs=[line_ty, line_ty, np.int32]
        )

Requirements

  • Python 3.10 or higher for type annotations with bracket notation
  • numpy 2.1 or higher for shape annotations for numpy.ndarray type

Internally, some of the code would look better if we bumped all the way to python 3.12+ so we could use type annotations with generics with bracket notation, but so far that is not necessary.

@hunhoffe hunhoffe changed the title Use np.ndarray and np dtypes instead of memref types in Python API Use numpy.ndarray and numpy dtypes instead of memref types in Python API Sep 23, 2024
@hunhoffe hunhoffe force-pushed the hide-memref-types branch 2 times, most recently from 4706a5d to 045764d Compare September 24, 2024 13:59
@hunhoffe hunhoffe closed this Sep 24, 2024
@hunhoffe hunhoffe reopened this Sep 24, 2024
@hunhoffe hunhoffe mentioned this pull request Sep 30, 2024
@hunhoffe hunhoffe changed the title Use numpy.ndarray and numpy dtypes instead of memref types in Python API [WIP] Use numpy.ndarray and numpy dtypes instead of memref types in Python API Sep 30, 2024
Copy link
Contributor

github-actions bot commented Sep 30, 2024

Coverage Report

Created: 2024-10-01 17:55

Click here for information about interpreting this report.

FilenameFunction CoverageLine CoverageRegion CoverageBranch Coverage
Totals- - - -
Generated by llvm-cov -- llvm version 14.0.0

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

1 participant