diff --git a/mlir/include/mlir-c/Dialect/SparseTensor.h b/mlir/include/mlir-c/Dialect/SparseTensor.h
index 7e47e54e7361d5..859a4f0dd9f52c 100644
--- a/mlir/include/mlir-c/Dialect/SparseTensor.h
+++ b/mlir/include/mlir-c/Dialect/SparseTensor.h
@@ -51,11 +51,10 @@ MLIR_CAPI_EXPORTED bool
 mlirAttributeIsASparseTensorEncodingAttr(MlirAttribute attr);
 
 /// Creates a `sparse_tensor.encoding` attribute with the given parameters.
-/// TODO: add a version that supplied lvlToDim when it cannot be inferred
 MLIR_CAPI_EXPORTED MlirAttribute mlirSparseTensorEncodingAttrGet(
     MlirContext ctx, intptr_t lvlRank,
     enum MlirSparseTensorDimLevelType const *lvlTypes, MlirAffineMap dimToLvl,
-    int posWidth, int crdWidth);
+    MlirAffineMap lvlToDim, int posWidth, int crdWidth);
 
 /// Returns the level-rank of the `sparse_tensor.encoding` attribute.
 MLIR_CAPI_EXPORTED intptr_t
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
index cbca0a7f8cc0e3..6e834426b44176 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
@@ -160,6 +160,19 @@ inline bool hasAnySparseOperandOrResult(Operation *op) {
   return hasAnySparseOperand(op) || hasAnySparseResult(op);
 }
 
+//
+// Inference.
+//
+
+/// Given the dimToLvl map, infers the lvlToDim map, or returns an
+/// empty AffineMap if inference fails.
+AffineMap inferLvlToDim(AffineMap dimToLvl, MLIRContext *context);
+
+/// Returns the lvlToDim map for the given dimToLvl map, specific
+/// to the block-sparse cases.
+/// Asserts on failure (so only use it when known to succeed).
+AffineMap inverseBlockSparsity(AffineMap dimToLvl, MLIRContext *context);
+
 //
 // Reordering.
 //
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
index 38c7200afb41ff..47fd18a689d5a8 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
@@ -307,6 +307,9 @@ def SparseTensorEncodingAttr : SparseTensor_Attr<"SparseTensorEncoding",
                                        "AffineMap":$lvlToDim,
                                        "unsigned":$posWidth,
                                        "unsigned":$crdWidth), [{
+      if (!lvlToDim) {
+        lvlToDim = ::mlir::sparse_tensor::inferLvlToDim(dimToLvl, $_ctxt);
+      }
       return $_get($_ctxt, lvlTypes, dimToLvl, lvlToDim, posWidth, crdWidth,
                    ArrayRef<::mlir::sparse_tensor::SparseTensorDimSliceAttr>{});
     }]>
diff --git a/mlir/lib/Bindings/Python/DialectSparseTensor.cpp b/mlir/lib/Bindings/Python/DialectSparseTensor.cpp
index 8e9e0b6baf76c2..9bde3a443ecfec 100644
--- a/mlir/lib/Bindings/Python/DialectSparseTensor.cpp
+++ b/mlir/lib/Bindings/Python/DialectSparseTensor.cpp
@@ -41,16 +41,17 @@ static void populateDialectSparseTensorSubmodule(const py::module &m) {
       .def_classmethod(
           "get",
           [](py::object cls, std::vector<MlirSparseTensorDimLevelType> lvlTypes,
-             std::optional<MlirAffineMap> dimToLvl, int posWidth, int crdWidth,
+             std::optional<MlirAffineMap> dimToLvl,
+             std::optional<MlirAffineMap> lvlToDim, int posWidth, int crdWidth,
              MlirContext context) {
-            // TODO: provide dimToLvl
             return cls(mlirSparseTensorEncodingAttrGet(
                 context, lvlTypes.size(), lvlTypes.data(),
-                dimToLvl ? *dimToLvl : MlirAffineMap{nullptr}, posWidth,
+                dimToLvl ? *dimToLvl : MlirAffineMap{nullptr},
+                lvlToDim ? *lvlToDim : MlirAffineMap{nullptr}, posWidth,
                 crdWidth));
           },
           py::arg("cls"), py::arg("lvl_types"), py::arg("dim_to_lvl"),
-          py::arg("pos_width"), py::arg("crd_width"),
+          py::arg("lvl_to_dim"), py::arg("pos_width"), py::arg("crd_width"),
          py::arg("context") = py::none(),
           "Gets a sparse_tensor.encoding from parameters.")
       .def_property_readonly(
@@ -71,6 +72,14 @@ static void populateDialectSparseTensorSubmodule(const py::module &m) {
               return {};
             return ret;
           })
+      .def_property_readonly(
+          "lvl_to_dim",
+          [](MlirAttribute self) -> std::optional<MlirAffineMap> {
+            MlirAffineMap ret = mlirSparseTensorEncodingAttrGetLvlToDim(self);
+            if (mlirAffineMapIsNull(ret))
+              return {};
+            return ret;
+          })
       .def_property_readonly("pos_width",
                              mlirSparseTensorEncodingAttrGetPosWidth)
       .def_property_readonly("crd_width",
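
Note: with the builder change in SparseTensorAttrDefs.td above, passing an
empty `AffineMap` for `lvlToDim` requests inference from `dimToLvl`. A minimal
C++ sketch of that path (not part of the patch; the CSR-style parameters and
the helper name are illustrative assumptions):

    #include <cassert>
    #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
    #include "mlir/IR/AffineMap.h"
    #include "mlir/IR/MLIRContext.h"

    using namespace mlir;
    using namespace mlir::sparse_tensor;

    void buildCsrEncoding(MLIRContext *context) {
      // CSR: levels follow dimension order; the inner level is compressed.
      SmallVector<DimLevelType> lvlTypes = {DimLevelType::Dense,
                                            DimLevelType::Compressed};
      AffineMap dimToLvl = AffineMap::getMultiDimIdentityMap(2, context);
      // An empty lvlToDim asks the builder to infer the inverse map.
      auto enc = SparseTensorEncodingAttr::get(context, lvlTypes, dimToLvl,
                                               /*lvlToDim=*/AffineMap(),
                                               /*posWidth=*/0, /*crdWidth=*/0);
      // For a permutation map, inference is just inversePermutation.
      assert(enc.getLvlToDim() == inversePermutation(dimToLvl));
    }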
diff --git a/mlir/lib/CAPI/Dialect/SparseTensor.cpp b/mlir/lib/CAPI/Dialect/SparseTensor.cpp
index bf3a4ad5e7a168..c3ad95527df489 100644
--- a/mlir/lib/CAPI/Dialect/SparseTensor.cpp
+++ b/mlir/lib/CAPI/Dialect/SparseTensor.cpp
@@ -48,15 +48,14 @@ bool mlirAttributeIsASparseTensorEncodingAttr(MlirAttribute attr) {
 
 MlirAttribute mlirSparseTensorEncodingAttrGet(
     MlirContext ctx, intptr_t lvlRank,
     MlirSparseTensorDimLevelType const *lvlTypes,
-    MlirAffineMap dimToLvl, int posWidth,
-    int crdWidth) {
+    MlirAffineMap dimToLvl, MlirAffineMap lvlToDim,
+    int posWidth, int crdWidth) {
   SmallVector<DimLevelType> cppLvlTypes;
   cppLvlTypes.reserve(lvlRank);
   for (intptr_t l = 0; l < lvlRank; ++l)
     cppLvlTypes.push_back(static_cast<DimLevelType>(lvlTypes[l]));
-  mlir::AffineMap lvlToDim; // TODO: provide in API
   return wrap(SparseTensorEncodingAttr::get(unwrap(ctx), cppLvlTypes,
-                                            unwrap(dimToLvl), lvlToDim,
+                                            unwrap(dimToLvl), unwrap(lvlToDim),
                                             posWidth, crdWidth));
 }
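
Since a null `MlirAffineMap` handle unwraps to the empty `AffineMap`, C-API
callers get the same inference behavior. A hedged sketch (the helper name and
CSR-style parameters are assumptions for illustration, not from the patch):

    #include "mlir-c/AffineMap.h"
    #include "mlir-c/Dialect/SparseTensor.h"

    static MlirAttribute makeInferredEncoding(MlirContext ctx) {
      // Dense outer level, compressed inner level (CSR-like).
      enum MlirSparseTensorDimLevelType lvlTypes[2] = {
          MLIR_SPARSE_TENSOR_DIM_LEVEL_DENSE,
          MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED};
      MlirAffineMap dimToLvl = mlirAffineMapMultiDimIdentityGet(ctx, 2);
      MlirAffineMap lvlToDim = {nullptr}; // null handle: infer lvlToDim
      return mlirSparseTensorEncodingAttrGet(ctx, /*lvlRank=*/2, lvlTypes,
                                             dimToLvl, lvlToDim,
                                             /*posWidth=*/0, /*crdWidth=*/0);
    }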
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index cd1e585438ddac..fd87bbfa905ed6 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -293,9 +293,8 @@ Type SparseTensorEncodingAttr::getCrdType() const {
 
 SparseTensorEncodingAttr
 SparseTensorEncodingAttr::withDimToLvl(AffineMap dimToLvl) const {
   assert(getImpl() && "Uninitialized SparseTensorEncodingAttr");
-  // TODO: infer lvlToDim
   return SparseTensorEncodingAttr::get(getContext(), getLvlTypes(), dimToLvl,
-                                       /*lvlToDim*/ AffineMap(), getPosWidth(),
+                                       getLvlToDim(), getPosWidth(),
                                        getCrdWidth());
 }
 
@@ -583,7 +582,8 @@ Attribute SparseTensorEncodingAttr::parse(AsmParser &parser, Type type) {
 #undef RETURN_ON_FAIL
 
   // Construct struct-like storage for attribute.
-  AffineMap lvlToDim; // TODO: infer
+  // TODO: also fetch lvlToDim when the user provides one explicitly.
+  AffineMap lvlToDim = inferLvlToDim(dimToLvl, parser.getContext());
   return parser.getChecked<SparseTensorEncodingAttr>(
       parser.getContext(), lvlTypes, dimToLvl, lvlToDim, posWidth, crdWidth,
       dimSlices);
@@ -749,6 +749,75 @@ mlir::sparse_tensor::getSparseTensorEncoding(Type type) {
   return nullptr;
 }
 
+AffineMap mlir::sparse_tensor::inferLvlToDim(AffineMap dimToLvl,
+                                             MLIRContext *context) {
+  auto map = static_cast<AffineMap>(dimToLvl);
+  AffineMap lvlToDim;
+  // Return an empty lvlToDim when inference is not successful.
+  if (!map || map.getNumSymbols() != 0) {
+    lvlToDim = AffineMap();
+  } else if (map.isPermutation()) {
+    lvlToDim = inversePermutation(map);
+  } else {
+    // TODO: check if it's block sparsity
+    lvlToDim = inverseBlockSparsity(map, context);
+  }
+  return lvlToDim;
+}
+
+AffineMap mlir::sparse_tensor::inverseBlockSparsity(AffineMap dimToLvl,
+                                                    MLIRContext *context) {
+  SmallVector<AffineExpr> lvlExprs;
+  auto numLvls = dimToLvl.getNumResults();
+  lvlExprs.reserve(numLvls);
+  // lvlExprComponents stores information of the floordiv and mod operations
+  // applied to the same dimension, so as to build the lvlToDim map.
+  std::map<unsigned, SmallVector<AffineExpr, 3>> lvlExprComponents;
+  for (unsigned i = 0, n = numLvls; i < n; i++) {
+    auto result = dimToLvl.getResult(i);
+    if (auto binOp = result.dyn_cast<AffineBinaryOpExpr>()) {
+      if (result.getKind() == AffineExprKind::FloorDiv) {
+        // Position of the dimension in dimToLvl.
+        auto pos = binOp.getLHS().dyn_cast<AffineDimExpr>().getPosition();
+        assert(lvlExprComponents.find(pos) == lvlExprComponents.end() &&
+               "expected only one floordiv for each dimension");
+        SmallVector<AffineExpr, 3> components;
+        // Level variable for floordiv.
+        components.push_back(getAffineDimExpr(i, context));
+        // Multiplier.
+        components.push_back(binOp.getRHS());
+        // Map key is the position of the dimension.
+        lvlExprComponents[pos] = components;
+      } else if (result.getKind() == AffineExprKind::Mod) {
+        auto pos = binOp.getLHS().dyn_cast<AffineDimExpr>().getPosition();
+        assert(lvlExprComponents.find(pos) != lvlExprComponents.end() &&
+               "expected floordiv before mod");
+        // Add the level variable for mod to the same vector
+        // as the corresponding floordiv.
+        lvlExprComponents[pos].push_back(getAffineDimExpr(i, context));
+      } else {
+        assert(false && "expected floordiv or mod");
+      }
+    } else {
+      lvlExprs.push_back(getAffineDimExpr(i, context));
+    }
+  }
+  // Build lvlExprs from lvlExprComponents.
+  // For example, for il = i floordiv 2 and ii = i mod 2, the components
+  // would be [il, 2, ii]. These are used to build the AffineExpr
+  // i = il * 2 + ii in lvlToDim.
+  for (auto &components : lvlExprComponents) {
+    assert(components.second.size() == 3 &&
+           "expected 3 components to build lvlExprs");
+    auto mulOp = getAffineBinaryOpExpr(
+        AffineExprKind::Mul, components.second[0], components.second[1]);
+    auto addOp =
+        getAffineBinaryOpExpr(AffineExprKind::Add, mulOp, components.second[2]);
+    lvlExprs.push_back(addOp);
+  }
+  return dimToLvl.get(dimToLvl.getNumResults(), 0, lvlExprs, context);
+}
+
 bool mlir::sparse_tensor::isCOOType(SparseTensorEncodingAttr enc,
                                     Level startLvl, bool isUnique) {
   if (!enc ||
@@ -811,7 +880,7 @@ RankedTensorType sparse_tensor::getCOOFromTypeWithOrdering(RankedTensorType rtt,
   // default value.
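
To make the inversion concrete: each (floordiv, mod) pair over a dimension
recombines as dim = lvl_div * blockSize + lvl_mod. A hedged test-style sketch
for a 2x2 BSR map (the driver function is an assumption, not part of the
patch):

    #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
    #include "mlir/IR/AffineExpr.h"
    #include "mlir/IR/AffineMap.h"
    #include "mlir/IR/MLIRContext.h"

    using namespace mlir;

    void demoBlockInverse(MLIRContext *context) {
      AffineExpr d0 = getAffineDimExpr(0, context);
      AffineExpr d1 = getAffineDimExpr(1, context);
      // BSR dimToLvl:
      //   (d0, d1) -> (d0 floordiv 2, d1 floordiv 2, d0 mod 2, d1 mod 2)
      AffineMap dimToLvl =
          AffineMap::get(/*dimCount=*/2, /*symbolCount=*/0,
                         {d0.floorDiv(2), d1.floorDiv(2), d0 % 2, d1 % 2},
                         context);
      // Not a permutation, so inference falls through to inverseBlockSparsity:
      //   (d0, d1, d2, d3) -> (d0 * 2 + d2, d1 * 2 + d3)
      AffineMap lvlToDim = sparse_tensor::inferLvlToDim(dimToLvl, context);
      lvlToDim.dump();
    }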
   unsigned posWidth = src.getPosWidth();
   unsigned crdWidth = src.getCrdWidth();
-  AffineMap invPerm; // TODO
+  AffineMap invPerm = src.getLvlToDim();
   auto enc = SparseTensorEncodingAttr::get(src.getContext(), lvlTypes, lvlPerm,
                                            invPerm, posWidth, crdWidth);
   return RankedTensorType::get(src.getDimShape(), src.getElementType(), enc);
diff --git a/mlir/test/CAPI/sparse_tensor.c b/mlir/test/CAPI/sparse_tensor.c
index 33ee8e784096a1..3bd1508cf299a3 100644
--- a/mlir/test/CAPI/sparse_tensor.c
+++ b/mlir/test/CAPI/sparse_tensor.c
@@ -40,6 +40,8 @@ static int testRoundtripEncoding(MlirContext ctx) {
   // CHECK: level_type: 4
   // CHECK: level_type: 8
   // CHECK: level_type: 8
+  MlirAffineMap lvlToDim =
+      mlirSparseTensorEncodingAttrGetLvlToDim(originalAttr);
   int lvlRank = mlirSparseTensorEncodingGetLvlRank(originalAttr);
   enum MlirSparseTensorDimLevelType *lvlTypes =
       malloc(sizeof(enum MlirSparseTensorDimLevelType) * lvlRank);
@@ -53,9 +55,8 @@ static int testRoundtripEncoding(MlirContext ctx) {
   // CHECK: crdWidth: 64
   int crdWidth = mlirSparseTensorEncodingAttrGetCrdWidth(originalAttr);
   fprintf(stderr, "crdWidth: %d\n", crdWidth);
-  // TODO: lvlToDim
   MlirAttribute newAttr = mlirSparseTensorEncodingAttrGet(
-      ctx, lvlRank, lvlTypes, dimToLvl, posWidth, crdWidth);
+      ctx, lvlRank, lvlTypes, dimToLvl, lvlToDim, posWidth, crdWidth);
   mlirAttributeDump(newAttr); // For debugging filecheck output.
   // CHECK: equal: 1
   fprintf(stderr, "equal: %d\n", mlirAttributeEqual(originalAttr, newAttr));
diff --git a/mlir/test/Dialect/SparseTensor/roundtrip_encoding.mlir b/mlir/test/Dialect/SparseTensor/roundtrip_encoding.mlir
index ae3805d8b77417..ea8217ab6e3f23 100644
--- a/mlir/test/Dialect/SparseTensor/roundtrip_encoding.mlir
+++ b/mlir/test/Dialect/SparseTensor/roundtrip_encoding.mlir
@@ -160,6 +160,24 @@ func.func private @BSR(%arg0: tensor<?x?xf64, #BSR>) {
 
 // -----
 
+#BCSR = #sparse_tensor.encoding<{
+  map = ( i, j, k ) ->
+  ( i floordiv 2 : dense,
+    j floordiv 3 : dense,
+    k floordiv 4 : compressed,
+    i mod 2 : dense,
+    j mod 3 : dense,
+    k mod 4 : dense
+  )
+}>
+
+// CHECK-LABEL: func private @BCSR(
+// CHECK-SAME: tensor<?x?x?xf64, #sparse_tensor.encoding<{ map = (d0, d1, d2) -> (d0 floordiv 2 : dense, d1 floordiv 3 : dense, d2 floordiv 4 : compressed, d0 mod 2 : dense, d1 mod 3 : dense, d2 mod 4 : dense) }>>
+func.func private @BCSR(%arg0: tensor<?x?x?xf64, #BCSR>) {
+  return
+}
+// -----
+
 #BSR_explicit = #sparse_tensor.encoding<{
   map =
   {il, jl, ii, jj}
@@ -194,3 +212,37 @@ func.func private @BSR_explicit(%arg0: tensor<?x?xf64, #BSR_explicit>) {
 func.func private @NV_24(%arg0: tensor<?x?xf64, #NV_24>) {
   return
 }
+
+// -----
+
+#NV_24 = #sparse_tensor.encoding<{
+  map = ( i, j, k ) ->
+  ( i : dense,
+    j : dense,
+    k floordiv 4 : dense,
+    k mod 4 : block2_4
+  )
+}>
+
+// CHECK-LABEL: func private @NV_24(
+// CHECK-SAME: tensor<?x?x?xf64, #sparse_tensor.encoding<{ map = (d0, d1, d2) -> (d0 : dense, d1 : dense, d2 floordiv 4 : dense, d2 mod 4 : block2_4) }>>
+func.func private @NV_24(%arg0: tensor<?x?x?xf64, #NV_24>) {
+  return
+}
+
+// -----
+
+#NV_24 = #sparse_tensor.encoding<{
+  map = ( i, j, k ) ->
+  ( i : dense,
+    k floordiv 4 : dense,
+    j : dense,
+    k mod 4 : block2_4
+  )
+}>
+
+// CHECK-LABEL: func private @NV_24(
+// CHECK-SAME: tensor<?x?x?xf64, #sparse_tensor.encoding<{ map = (d0, d1, d2) -> (d0 : dense, d2 floordiv 4 : dense, d1 : dense, d2 mod 4 : block2_4) }>>
+func.func private @NV_24(%arg0: tensor<?x?x?xf64, #NV_24>) {
+  return
+}
\ No newline at end of file
diff --git a/mlir/test/Integration/Dialect/SparseTensor/python/test_SDDMM.py b/mlir/test/Integration/Dialect/SparseTensor/python/test_SDDMM.py
index 0cdc7c88bd97fb..1f9b6360383180 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/python/test_SDDMM.py
+++ b/mlir/test/Integration/Dialect/SparseTensor/python/test_SDDMM.py
@@ -155,7 +155,7 @@ def main():
             for iwidth in [32]:
                 for e in [True]:
                     attr = st.EncodingAttr.get(
-                        level, ordering, pwidth, iwidth
+                        level, ordering, None, pwidth, iwidth
                     )
                     opt = f"parallelization-strategy=none"
                     compiler = sparse_compiler.SparseCompiler(
diff --git a/mlir/test/Integration/Dialect/SparseTensor/python/test_SpMM.py b/mlir/test/Integration/Dialect/SparseTensor/python/test_SpMM.py
index 01d74a4dc82fa1..69f6cdcea967fa 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/python/test_SpMM.py
+++ b/mlir/test/Integration/Dialect/SparseTensor/python/test_SpMM.py
@@ -145,7 +145,7 @@ def main():
         for pwidth in bitwidths:
             for iwidth in bitwidths:
                 attr = st.EncodingAttr.get(
-                    level, ordering, pwidth, iwidth
+                    level, ordering, None, pwidth, iwidth
                 )
                 build_compile_and_run_SpMM(attr, compiler)
                 count = count + 1
diff --git a/mlir/test/Integration/Dialect/SparseTensor/python/test_output.py b/mlir/test/Integration/Dialect/SparseTensor/python/test_output.py
index 8f3f4e5af1e58e..7d774900802051 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/python/test_output.py
+++ b/mlir/test/Integration/Dialect/SparseTensor/python/test_output.py
@@ -91,7 +91,7 @@ def main():
     for level in levels:
         for ordering in orderings:
             for bwidth in bitwidths:
-                attr = st.EncodingAttr.get(level, ordering, bwidth, bwidth)
+                attr = st.EncodingAttr.get(level, ordering, None, bwidth, bwidth)
                 build_compile_and_run_output(attr, compiler)
                 count = count + 1
diff --git a/mlir/test/Integration/Dialect/SparseTensor/python/test_stress.py b/mlir/test/Integration/Dialect/SparseTensor/python/test_stress.py
index ef266672ce42af..841b02bc10c8be 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/python/test_stress.py
+++ b/mlir/test/Integration/Dialect/SparseTensor/python/test_stress.py
@@ -227,7 +227,7 @@ def main():
                 for pwidth in bitwidths:
                     for iwidth in bitwidths:
                         attr = st.EncodingAttr.get(
-                            level, ordering, pwidth, iwidth
+                            level, ordering, None, pwidth, iwidth
                        )
                         types.append(ir.RankedTensorType.get(shape, f64, attr))
                         #
diff --git a/mlir/test/python/dialects/sparse_tensor/dialect.py b/mlir/test/python/dialects/sparse_tensor/dialect.py
index d80b878323377a..240db6ebd1d1eb 100644
--- a/mlir/test/python/dialects/sparse_tensor/dialect.py
+++ b/mlir/test/python/dialects/sparse_tensor/dialect.py
@@ -32,12 +32,14 @@ def testEncodingAttr1D():
     print(f"lvl_types: {casted.lvl_types}")
     # CHECK: dim_to_lvl: None
     print(f"dim_to_lvl: {casted.dim_to_lvl}")
+    # CHECK: lvl_to_dim: None
+    print(f"lvl_to_dim: {casted.lvl_to_dim}")
     # CHECK: pos_width: 16
     print(f"pos_width: {casted.pos_width}")
     # CHECK: crd_width: 32
     print(f"crd_width: {casted.crd_width}")
 
-    created = st.EncodingAttr.get(casted.lvl_types, None, 0, 0)
+    created = st.EncodingAttr.get(casted.lvl_types, None, None, 0, 0)
     # CHECK: #sparse_tensor.encoding<{ map = (d0) -> (d0 : compressed) }>
     print(created)
     # CHECK: created_equal: False
@@ -72,12 +74,20 @@ def testEncodingAttr2D():
     print(f"lvl_types: {casted.lvl_types}")
     # CHECK: dim_to_lvl: (d0, d1) -> (d1, d0)
     print(f"dim_to_lvl: {casted.dim_to_lvl}")
+    # CHECK: lvl_to_dim: (d0, d1) -> (d1, d0)
+    print(f"lvl_to_dim: {casted.lvl_to_dim}")
     # CHECK: pos_width: 8
     print(f"pos_width: {casted.pos_width}")
     # CHECK: crd_width: 32
     print(f"crd_width: {casted.crd_width}")
 
-    created = st.EncodingAttr.get(casted.lvl_types, casted.dim_to_lvl, 8, 32)
+    created = st.EncodingAttr.get(
+        casted.lvl_types,
+        casted.dim_to_lvl,
+        casted.lvl_to_dim,
+        8,
+        32,
+    )
     # CHECK: #sparse_tensor.encoding<{ map = (d0, d1) -> (d1 : dense, d0 : compressed), posWidth = 8, crdWidth = 32 }>
     print(created)
     # CHECK: created_equal: True