Skip to content

Commit

Permalink
[mlir][sparse] rename sparse_tensor.(un)pack to sparse_tensor.(dis)as… (
Browse files Browse the repository at this point in the history
#67717)

…semble

Pack/Unpack are overridden in many other places, rename the operations
to avoid confusion.
  • Loading branch information
PeimingLiu authored Sep 28, 2023
1 parent 9f2fc88 commit 6ca47eb
Show file tree
Hide file tree
Showing 13 changed files with 62 additions and 58 deletions.
12 changes: 6 additions & 6 deletions mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -53,14 +53,14 @@ def SparseTensor_NewOp : SparseTensor_Op<"new", [Pure]>,
let assemblyFormat = "$source attr-dict `:` type($source) `to` type($result)";
}

def SparseTensor_PackOp : SparseTensor_Op<"pack", [Pure]>,
def SparseTensor_AssembleOp : SparseTensor_Op<"assemble", [Pure]>,
Arguments<(ins TensorOf<[AnyType]>:$values,
Variadic<TensorOf<[AnySignlessIntegerOrIndex]>>:$levels)>,
Results<(outs AnySparseTensor: $result)> {
let summary = "Returns a sparse tensor from the given values, levels";

let description = [{
Packs the values and per-level coordinate or postion arrays into a sparse tensor.
Assembles the values and per-level coordinate or postion arrays into a sparse tensor.
The order and types of provided levels must be consistent with the actual storage
layout of the returned sparse tensor described below.

Expand All @@ -87,7 +87,7 @@ def SparseTensor_PackOp : SparseTensor_Op<"pack", [Pure]>,
```mlir
%values = arith.constant dense<[ 1.1, 2.2, 3.3 ]> : tensor<3xf64>
%coordinates = arith.constant dense<[[0,0], [1,2], [1,3]]> : tensor<3x2xindex>
%st = sparse_tensor.pack %values, %coordinates
%st = sparse_tensor.assemble %values, %coordinates
: tensor<3xf64>, tensor<3x2xindex> to tensor<3x4xf64, #COO>
// yields COO format |1.1, 0.0, 0.0, 0.0|
// of 3x4 matrix |0.0, 0.0, 2.2, 3.3|
Expand All @@ -102,7 +102,7 @@ def SparseTensor_PackOp : SparseTensor_Op<"pack", [Pure]>,
let hasVerifier = 1;
}

def SparseTensor_UnpackOp : SparseTensor_Op<"unpack", [Pure, SameVariadicResultSize]>,
def SparseTensor_DisassembleOp : SparseTensor_Op<"disassemble", [Pure, SameVariadicResultSize]>,
Arguments<(ins AnySparseTensor:$tensor,
TensorOf<[AnyType]>:$out_values,
Variadic<TensorOf<[AnySignlessIntegerOrIndex]>>:$out_levels)>,
Expand All @@ -113,7 +113,7 @@ def SparseTensor_UnpackOp : SparseTensor_Op<"unpack", [Pure, SameVariadicResultS
let summary = "Returns the (values, coordinates) pair unpacked from the input tensor";

let description = [{
The unpack operation is the inverse of `sparse_tensor::pack`. It returns
The disassemble operation is the inverse of `sparse_tensor::assemble`. It returns
the values and per-level position and coordinate array to the user
from the sparse tensor along with the actual length of the memory used in
each returned buffer. This operation can be used for returning an
Expand All @@ -132,7 +132,7 @@ def SparseTensor_UnpackOp : SparseTensor_Op<"unpack", [Pure, SameVariadicResultS
// of 3x4 matrix |0.0, 0.0, 2.2, 3.3|
// |0.0, 0.0, 0.0, 0.0|
%v, %p, %c, %v_len, %p_len, %c_len =
sparse_tensor.unpack %sp : tensor<3x4xf64, #COO>
sparse_tensor.disassemble %sp : tensor<3x4xf64, #COO>
outs(%od, %op, %oi : tensor<3xf64>, tensor<2xindex>, tensor<3x2xindex>)
-> tensor<3xf64>, (tensor<2xindex>, tensor<3x2xindex>), index, (index, index)
// %v = arith.constant dense<[ 1.1, 2.2, 3.3 ]> : tensor<3xf64>
Expand Down
4 changes: 2 additions & 2 deletions mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -974,14 +974,14 @@ static LogicalResult verifyPackUnPack(Operation *op, bool requiresStaticShape,
return success();
}

LogicalResult PackOp::verify() {
LogicalResult AssembleOp::verify() {
const auto valuesTp = getRankedTensorType(getValues());
const auto lvlsTp = getLevels().getTypes();
const auto resTp = getSparseTensorType(getResult());
return verifyPackUnPack(*this, true, resTp, valuesTp, lvlsTp);
}

LogicalResult UnpackOp::verify() {
LogicalResult DisassembleOp::verify() {
if (getOutValues().getType() != getRetValues().getType())
return emitError("output values and return value type mismatch");

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -122,11 +122,11 @@ struct NewOpInterface
bool bufferizesToAllocation(Operation *op, Value value) const { return true; }
};

struct PackOpInterface
: public SparseBufferizableOpInterfaceExternalModel<PackOpInterface,
sparse_tensor::PackOp> {
struct AssembleOpInterface
: public SparseBufferizableOpInterfaceExternalModel<
AssembleOpInterface, sparse_tensor::AssembleOp> {
bool bufferizesToAllocation(Operation *op, Value value) const {
// PackOp reuses all the buffers instead of allocating new ones
// AssembleOp reuses all the buffers instead of allocating new ones
return false;
}

Expand All @@ -143,7 +143,7 @@ struct PackOpInterface
AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
const AnalysisState &state) const {
assert(op->getNumResults() == 1);
// PackOp reuses the input tensors as values/coordinates instead of
// AssembleOp reuses the input tensors as values/coordinates instead of
// creating new ones when packing into a COO format.
return {{op->getOpResult(0), BufferRelation::Equivalent}};
}
Expand All @@ -154,8 +154,9 @@ struct PackOpInterface
}
};

struct UnpackOpInterface : public SparseBufferizableOpInterfaceExternalModel<
UnpackOpInterface, sparse_tensor::UnpackOp> {
struct DisassembleOpInterface
: public SparseBufferizableOpInterfaceExternalModel<
DisassembleOpInterface, sparse_tensor::DisassembleOp> {
bool bufferizesToAllocation(Operation *op, Value value) const {
// The output buffer is pre-allocated by the user.
return false;
Expand Down Expand Up @@ -326,8 +327,8 @@ void mlir::sparse_tensor::registerBufferizableOpInterfaceExternalModels(
sparse_tensor::InsertOp::attachInterface<InsertOpInterface>(*ctx);
sparse_tensor::NumberOfEntriesOp::attachInterface<
NumberOfEntriesOpInterface>(*ctx);
sparse_tensor::PackOp::attachInterface<PackOpInterface>(*ctx);
sparse_tensor::UnpackOp::attachInterface<UnpackOpInterface>(*ctx);
sparse_tensor::AssembleOp::attachInterface<AssembleOpInterface>(*ctx);
sparse_tensor::DisassembleOp::attachInterface<DisassembleOpInterface>(*ctx);
sparse_tensor::ToCoordinatesBufferOp::attachInterface<
ToCoordinatesBufferOpInterface>(*ctx);
sparse_tensor::ToCoordinatesOp::attachInterface<ToCoordinatesOpInterface>(
Expand Down
7 changes: 4 additions & 3 deletions mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -795,10 +795,10 @@ rewriteSpGEMM(PatternRewriter &rewriter, linalg::GenericOp op, bool enableRT,
Value rowC = e1.getResult(0);
token = e1.getAsyncToken();
auto e2 = genAllocBuffer(rewriter, loc, cTp.getCrdType(), zero, token);
Value colC = e2.getResult(0); // no free needed
Value colC = e2.getResult(0); // no free needed
token = e2.getAsyncToken();
auto e3 = genAllocBuffer(rewriter, loc, dnCType, zero, token);
Value valC = e3.getResult(0); // no free needed
Value valC = e3.getResult(0); // no free needed
token = e3.getAsyncToken();
Operation *spGenC =
genSpMat(rewriter, loc, spmatHandleTp, tokenTp, token, szm, szn, zero,
Expand Down Expand Up @@ -900,7 +900,8 @@ rewriteSpGEMM(PatternRewriter &rewriter, linalg::GenericOp op, bool enableRT,
Value vt = rewriter.create<bufferization::ToTensorOp>(loc, valH);
Value rt = rewriter.create<bufferization::ToTensorOp>(loc, rowH);
Value ct = rewriter.create<bufferization::ToTensorOp>(loc, colH);
rewriter.replaceOpWithNewOp<PackOp>(op, c.getType(), vt, ValueRange{rt, ct});
rewriter.replaceOpWithNewOp<AssembleOp>(op, c.getType(), vt,
ValueRange{rt, ct});
return success();
}

Expand Down
14 changes: 8 additions & 6 deletions mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1244,10 +1244,10 @@ class SparseNumberOfEntriesConverter
}
};

struct SparsePackOpConverter : public OpConversionPattern<PackOp> {
struct SparseAssembleOpConverter : public OpConversionPattern<AssembleOp> {
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(PackOp op, OpAdaptor adaptor,
matchAndRewrite(AssembleOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Location loc = op.getLoc();
const auto stt = getSparseTensorType(op.getResult());
Expand Down Expand Up @@ -1347,13 +1347,15 @@ struct SparsePackOpConverter : public OpConversionPattern<PackOp> {
}
};

struct SparseUnpackOpConverter : public OpConversionPattern<UnpackOp> {
struct SparseDisassembleOpConverter
: public OpConversionPattern<DisassembleOp> {
using OpConversionPattern::OpConversionPattern;
SparseUnpackOpConverter(TypeConverter &typeConverter, MLIRContext *context)
SparseDisassembleOpConverter(TypeConverter &typeConverter,
MLIRContext *context)
: OpConversionPattern(typeConverter, context) {}

LogicalResult
matchAndRewrite(UnpackOp op, OpAdaptor adaptor,
matchAndRewrite(DisassembleOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto desc = getDescriptorFromTensorTuple(adaptor.getTensor());
Location loc = op.getLoc();
Expand Down Expand Up @@ -1571,7 +1573,7 @@ struct SparseNewOpConverter : public OpConversionPattern<NewOp> {
void mlir::populateSparseTensorCodegenPatterns(
TypeConverter &typeConverter, RewritePatternSet &patterns,
bool createSparseDeallocs, bool enableBufferInitialization) {
patterns.add<SparsePackOpConverter, SparseUnpackOpConverter,
patterns.add<SparseAssembleOpConverter, SparseDisassembleOpConverter,
SparseReturnConverter, SparseCallConverter, SparseDimOpConverter,
SparseCastConverter, SparseExtractSliceConverter,
SparseTensorLoadConverter, SparseExpandConverter,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1493,15 +1493,15 @@ class SparseTensorOutConverter : public OpConversionPattern<OutOp> {
};

/// Sparse conversion rule for the sparse_tensor.pack operator.
class SparseTensorPackConverter : public OpConversionPattern<PackOp> {
class SparseTensorAssembleConverter : public OpConversionPattern<AssembleOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(PackOp op, OpAdaptor adaptor,
matchAndRewrite(AssembleOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
const Location loc = op->getLoc();
const auto dstTp = getSparseTensorType(op.getResult());
// PackOps always returns a static shaped tensor result.
// AssembleOps always returns a static shaped tensor result.
assert(dstTp.hasStaticDimShape());
SmallVector<Value> dimSizes = getDimSizes(rewriter, loc, dstTp);
Value dst =
Expand Down Expand Up @@ -1546,7 +1546,7 @@ void mlir::populateSparseTensorConversionPatterns(
SparseTensorToValuesConverter, SparseNumberOfEntriesConverter,
SparseTensorLoadConverter, SparseTensorInsertConverter,
SparseTensorExpandConverter, SparseTensorCompressConverter,
SparseTensorOutConverter, SparseTensorPackConverter>(
SparseTensorOutConverter, SparseTensorAssembleConverter>(
typeConverter, patterns.getContext());
patterns.add<SparseTensorConvertConverter>(typeConverter,
patterns.getContext(), options);
Expand Down
2 changes: 1 addition & 1 deletion mlir/test/Dialect/SparseTensor/GPU/gpu_spgemm_lib.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@
// CHECK: %[[VAL_a2:.*]] = bufferization.to_tensor %[[VAL_83]] : memref<?xf32>
// CHECK: %[[VAL_a3:.*]] = bufferization.to_tensor %[[VAL_81]] : memref<?xindex>
// CHECK: %[[VAL_a4:.*]] = bufferization.to_tensor %[[VAL_82]] : memref<?xindex>
// CHECK: %[[VAL_a5:.*]] = sparse_tensor.pack %[[VAL_a2]], %[[VAL_a3]], %[[VAL_a4]] : tensor<?xf32>, tensor<?xindex>, tensor<?xindex> to tensor<8x8xf32, #{{.*}}>
// CHECK: %[[VAL_a5:.*]] = sparse_tensor.assemble %[[VAL_a2]], %[[VAL_a3]], %[[VAL_a4]] : tensor<?xf32>, tensor<?xindex>, tensor<?xindex> to tensor<8x8xf32, #{{.*}}>
// CHECK: return %[[VAL_a5]] : tensor<8x8xf32, #{{.*}}>
// CHECK: }
func.func @matmulCSR(%A: tensor<8x8xf32, #CSR>,
Expand Down
14 changes: 7 additions & 7 deletions mlir/test/Dialect/SparseTensor/invalid.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ func.func @invalid_new_dense(%arg0: !llvm.ptr<i8>) -> tensor<32xf32> {
func.func @non_static_pack_ret(%values: tensor<6xf64>, %pos: tensor<2xi32>, %coordinates: tensor<6x1xi32>)
-> tensor<?xf64, #SparseVector> {
// expected-error@+1 {{the sparse-tensor must have static shape}}
%0 = sparse_tensor.pack %values, %pos, %coordinates
%0 = sparse_tensor.assemble %values, %pos, %coordinates
: tensor<6xf64>, tensor<2xi32>, tensor<6x1xi32> to tensor<?xf64, #SparseVector>
return %0 : tensor<?xf64, #SparseVector>
}
Expand All @@ -25,7 +25,7 @@ func.func @non_static_pack_ret(%values: tensor<6xf64>, %pos: tensor<2xi32>, %coo
func.func @invalid_pack_type(%values: tensor<6xf64>, %pos: tensor<2xi32>, %coordinates: tensor<6x1xi32>)
-> tensor<100xf32, #SparseVector> {
// expected-error@+1 {{input/output element-types don't match}}
%0 = sparse_tensor.pack %values, %pos, %coordinates
%0 = sparse_tensor.assemble %values, %pos, %coordinates
: tensor<6xf64>, tensor<2xi32>, tensor<6x1xi32> to tensor<100xf32, #SparseVector>
return %0 : tensor<100xf32, #SparseVector>
}
Expand All @@ -37,7 +37,7 @@ func.func @invalid_pack_type(%values: tensor<6xf64>, %pos: tensor<2xi32>, %coord
func.func @invalid_pack_type(%values: tensor<6xf64>, %pos: tensor<2xi32>, %coordinates: tensor<6x3xi32>)
-> tensor<100x2xf64, #SparseVector> {
// expected-error@+1 {{input/output trailing COO level-ranks don't match}}
%0 = sparse_tensor.pack %values, %pos, %coordinates
%0 = sparse_tensor.assemble %values, %pos, %coordinates
: tensor<6xf64>, tensor<2xi32>, tensor<6x3xi32> to tensor<100x2xf64, #SparseVector>
return %0 : tensor<100x2xf64, #SparseVector>
}
Expand All @@ -49,7 +49,7 @@ func.func @invalid_pack_type(%values: tensor<6xf64>, %pos: tensor<2xi32>, %coord
func.func @invalid_pack_mis_position(%values: tensor<6xf64>, %coordinates: tensor<6xi32>)
-> tensor<2x100xf64, #CSR> {
// expected-error@+1 {{inconsistent number of fields between input/output}}
%0 = sparse_tensor.pack %values, %coordinates
%0 = sparse_tensor.assemble %values, %coordinates
: tensor<6xf64>, tensor<6xi32> to tensor<2x100xf64, #CSR>
return %0 : tensor<2x100xf64, #CSR>
}
Expand All @@ -60,7 +60,7 @@ func.func @invalid_pack_mis_position(%values: tensor<6xf64>, %coordinates: tenso

func.func @invalid_unpack_type(%sp: tensor<100xf32, #SparseVector>, %values: tensor<6xf64>, %pos: tensor<2xi32>, %coordinates: tensor<6x1xi32>) {
// expected-error@+1 {{input/output element-types don't match}}
%rv, %rp, %rc, %vl, %pl, %cl = sparse_tensor.unpack %sp : tensor<100xf32, #SparseVector>
%rv, %rp, %rc, %vl, %pl, %cl = sparse_tensor.disassemble %sp : tensor<100xf32, #SparseVector>
outs(%values, %pos, %coordinates : tensor<6xf64>, tensor<2xi32>, tensor<6x1xi32>)
-> tensor<6xf64>, (tensor<2xi32>, tensor<6x1xi32>), index, (index, index)
return
Expand All @@ -72,7 +72,7 @@ func.func @invalid_unpack_type(%sp: tensor<100xf32, #SparseVector>, %values: ten

func.func @invalid_unpack_type(%sp: tensor<100x2xf64, #SparseVector>, %values: tensor<6xf64>, %pos: tensor<2xi32>, %coordinates: tensor<6x3xi32>) {
// expected-error@+1 {{input/output trailing COO level-ranks don't match}}
%rv, %rp, %rc, %vl, %pl, %cl = sparse_tensor.unpack %sp : tensor<100x2xf64, #SparseVector>
%rv, %rp, %rc, %vl, %pl, %cl = sparse_tensor.disassemble %sp : tensor<100x2xf64, #SparseVector>
outs(%values, %pos, %coordinates : tensor<6xf64>, tensor<2xi32>, tensor<6x3xi32>)
-> tensor<6xf64>, (tensor<2xi32>, tensor<6x3xi32>), index, (index, index)
return
Expand All @@ -84,7 +84,7 @@ func.func @invalid_unpack_type(%sp: tensor<100x2xf64, #SparseVector>, %values: t

func.func @invalid_unpack_mis_position(%sp: tensor<2x100xf64, #CSR>, %values: tensor<6xf64>, %coordinates: tensor<6xi32>) {
// expected-error@+1 {{inconsistent number of fields between input/output}}
%rv, %rc, %vl, %pl = sparse_tensor.unpack %sp : tensor<2x100xf64, #CSR>
%rv, %rc, %vl, %pl = sparse_tensor.disassemble %sp : tensor<2x100xf64, #CSR>
outs(%values, %coordinates : tensor<6xf64>, tensor<6xi32>)
-> tensor<6xf64>, (tensor<6xi32>), index, (index)
return
Expand Down
4 changes: 2 additions & 2 deletions mlir/test/Dialect/SparseTensor/pack_copy.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ func.func @foo(%arg0: tensor<3xf64> {bufferization.writable = false},
//
// Pack the buffers into a sparse tensors.
//
%pack = sparse_tensor.pack %arg0, %arg2, %arg1
%pack = sparse_tensor.assemble %arg0, %arg2, %arg1
: tensor<3xf64>,
tensor<11xi32>,
tensor<3xi32> to tensor<10x10xf64, #CSR>
Expand Down Expand Up @@ -76,7 +76,7 @@ func.func @bar(%arg0: tensor<3xf64> {bufferization.writable = true},
//
// Pack the buffers into a sparse tensors.
//
%pack = sparse_tensor.pack %arg0, %arg2, %arg1
%pack = sparse_tensor.assemble %arg0, %arg2, %arg1
: tensor<3xf64>,
tensor<11xi32>,
tensor<3xi32> to tensor<10x10xf64, #CSR>
Expand Down
8 changes: 4 additions & 4 deletions mlir/test/Dialect/SparseTensor/roundtrip.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,11 @@ func.func @sparse_new(%arg0: !llvm.ptr<i8>) -> tensor<128xf64, #SparseVector> {
// CHECK-SAME: %[[D:.*]]: tensor<6xf64>,
// CHECK-SAME: %[[P:.*]]: tensor<2xi32>,
// CHECK-SAME: %[[I:.*]]: tensor<6x1xi32>)
// CHECK: %[[R:.*]] = sparse_tensor.pack %[[D]], %[[P]], %[[I]]
// CHECK: %[[R:.*]] = sparse_tensor.assemble %[[D]], %[[P]], %[[I]]
// CHECK: return %[[R]] : tensor<100xf64, #{{.*}}>
func.func @sparse_pack(%data: tensor<6xf64>, %pos: tensor<2xi32>, %index: tensor<6x1xi32>)
-> tensor<100xf64, #SparseVector> {
%0 = sparse_tensor.pack %data, %pos, %index : tensor<6xf64>, tensor<2xi32>, tensor<6x1xi32>
%0 = sparse_tensor.assemble %data, %pos, %index : tensor<6xf64>, tensor<2xi32>, tensor<6x1xi32>
to tensor<100xf64, #SparseVector>
return %0 : tensor<100xf64, #SparseVector>
}
Expand All @@ -36,14 +36,14 @@ func.func @sparse_pack(%data: tensor<6xf64>, %pos: tensor<2xi32>, %index: tensor
// CHECK-SAME: %[[OD:.*]]: tensor<6xf64>
// CHECK-SAME: %[[OP:.*]]: tensor<2xindex>
// CHECK-SAME: %[[OI:.*]]: tensor<6x1xi32>
// CHECK: %[[D:.*]], %[[P:.*]]:2, %[[DL:.*]], %[[PL:.*]]:2 = sparse_tensor.unpack %[[T]]
// CHECK: %[[D:.*]], %[[P:.*]]:2, %[[DL:.*]], %[[PL:.*]]:2 = sparse_tensor.disassemble %[[T]]
// CHECK: return %[[D]], %[[P]]#0, %[[P]]#1
func.func @sparse_unpack(%sp : tensor<100xf64, #SparseVector>,
%od : tensor<6xf64>,
%op : tensor<2xindex>,
%oi : tensor<6x1xi32>)
-> (tensor<6xf64>, tensor<2xindex>, tensor<6x1xi32>) {
%rd, %rp, %ri, %vl, %pl, %cl = sparse_tensor.unpack %sp : tensor<100xf64, #SparseVector>
%rd, %rp, %ri, %vl, %pl, %cl = sparse_tensor.disassemble %sp : tensor<100xf64, #SparseVector>
outs(%od, %op, %oi : tensor<6xf64>, tensor<2xindex>, tensor<6x1xi32>)
-> tensor<6xf64>, (tensor<2xindex>, tensor<6x1xi32>), index, (index, index)
return %rd, %rp, %ri : tensor<6xf64>, tensor<2xindex>, tensor<6x1xi32>
Expand Down
4 changes: 2 additions & 2 deletions mlir/test/Dialect/SparseTensor/sparse_pack.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
// CHECK: }
func.func @sparse_pack(%values: tensor<6xf64>, %pos:tensor<2xindex>, %coordinates: tensor<6x2xi32>)
-> tensor<100x100xf64, #COO> {
%0 = sparse_tensor.pack %values, %pos, %coordinates
%0 = sparse_tensor.assemble %values, %pos, %coordinates
: tensor<6xf64>, tensor<2xindex>, tensor<6x2xi32> to tensor<100x100xf64, #COO>
return %0 : tensor<100x100xf64, #COO>
}
Expand Down Expand Up @@ -70,7 +70,7 @@ func.func @sparse_unpack(%sp : tensor<100x100xf64, #COO>,
%op : tensor<2xindex>,
%oi : tensor<6x2xi32>)
-> (tensor<6xf64>, tensor<2xindex>, tensor<6x2xi32>) {
%rd, %rp, %ri, %dl, %pl, %il = sparse_tensor.unpack %sp : tensor<100x100xf64, #COO>
%rd, %rp, %ri, %dl, %pl, %il = sparse_tensor.disassemble %sp : tensor<100x100xf64, #COO>
outs(%od, %op, %oi : tensor<6xf64>, tensor<2xindex>, tensor<6x2xi32>)
-> tensor<6xf64>, (tensor<2xindex>, tensor<6x2xi32>), index, (index, index)
return %rd, %rp, %ri : tensor<6xf64>, tensor<2xindex>, tensor<6x2xi32>
Expand Down
Loading

0 comments on commit 6ca47eb

Please sign in to comment.