Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WORK IN PROGRESS] tcp.scatter and tcp.scatter_nd #103

Draft
wants to merge 7 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 77 additions & 0 deletions include/mlir-tcp/Dialect/IR/TcpOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -662,6 +662,83 @@ def Tcp_GatherOp : Tcp_Op<"gather", [Pure, AllElementTypesMatch<["input", "out"]
let hasVerifier = 1;
}

def Tcp_GatherNDOp : Tcp_Op<"gather_nd", [Pure, AllElementTypesMatch<["input", "out"]>]> {

  let summary = "Gather elements from input based on indices over multiple dimensions";

  let description = [{
    Gathers elements from a given tensor based on indices that index along multiple dimensions.

    More details regarding this op: docs/gather.md
  }];

  let arguments = (ins
    Tcp_Tensor:$input,
    Tcp_IntTensor:$indices
  );

  let results = (outs
    Tcp_Tensor:$out
  );

  let assemblyFormat = "$input `,` $indices attr-dict `:` type($input) `,` type($indices) `->` type($out)";

  let hasVerifier = 1;
}

def Tcp_ScatterOp : Tcp_Op<"scatter", [Pure, AllElementTypesMatch<["input", "values", "out"]>]> {

  let summary = "Scatter elements from values into input based on indices";

  let description = [{
    Scatter elements from values into the input tensor based on indices that index along a given dim.

    The indexing of scatter is similar to gather, which is documented in docs/gather.md
  }];

  let arguments = (ins
    Tcp_Tensor:$input,
    Tcp_IntTensor:$indices,
    Tcp_Tensor:$values,
    IndexAttr:$dim
  );

  let results = (outs
    Tcp_Tensor:$out
  );

  // No `,` before attr-dict: matches the syntax of tcp.gather, tcp.gather_nd
  // and tcp.scatter_nd (the original format had a dangling comma).
  let assemblyFormat = "$input `,` $indices `,` $values attr-dict `:` type($input) `,` type($indices) `,` type($values) `->` type($out)";

  let hasVerifier = 1;

}

def Tcp_ScatterNDOp : Tcp_Op<"scatter_nd", [Pure, AllElementTypesMatch<["input", "values", "out"]>]> {

  let summary = "Scatter elements from values over the input based on indices over multiple dimensions";

  let description = [{
    Scatter elements from the values tensor over the input tensor according to the indices tensor
    along multiple dimensions.

    Note: the shape of the indices is similar to that of gather documented in docs/gather.md
  }];

  let arguments = (ins
    Tcp_Tensor:$input,
    Tcp_IntTensor:$indices,
    Tcp_Tensor:$values
  );

  let results = (outs
    Tcp_Tensor:$out
  );

  let assemblyFormat = "$input `,` $indices `,` $values attr-dict `:` type($input) `,` type($indices) `,` type($values) `->` type($out)";

  let hasVerifier = 1;
}

def Tcp_SliceOp : Tcp_Op<"slice", [Pure, AllElementTypesMatch<["in", "out"]>, SameVariadicOperandSize]> {

let summary = "Extracts a slice of the input tensor";
Expand Down
98 changes: 98 additions & 0 deletions lib/Conversion/TcpToLinalg/DataMovement.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,102 @@ class ConvertGatherOp : public OpConversionPattern<GatherOp> {
}
};

/**
* tcp.gather_nd is lowered to linalg.generic, which allows us to define every
* element in the result tensor using a programmatic expression. The last
* dimension of the indicies tensor is used to index into the input tensor.
*
* For example, we we have an indices tensor of shape 9x4x3x2 and an input
* tensor of shape 5x6x7x8, then the resulting tensor will be of shape
* 9x4x3x7x8. Where the first three dimensions of the resulting tensor are used
* to index into the indicies tensor. Then the last dimension of the index
* tensor (the 2 sized dimension) is used to index into the input tensor.
*/
/**
 * tcp.gather_nd is lowered to linalg.generic, which allows us to define every
 * element in the result tensor using a programmatic expression. The last
 * dimension of the indices tensor is used to index into the input tensor.
 *
 * For example, if we have an indices tensor of shape 9x4x3x2 and an input
 * tensor of shape 5x6x7x8, then the resulting tensor will be of shape
 * 9x4x3x7x8, where the first three dimensions of the resulting tensor are
 * used to index into the indices tensor. Then the last dimension of the
 * indices tensor (the 2 sized dimension) is used to index into the input
 * tensor.
 */
class ConvertGatherNDOp : public OpConversionPattern<GatherNDOp> {
public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(GatherNDOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op->getLoc();
    // Use the free-function cast<> form consistently (Type::cast<> is the
    // deprecated member spelling; the rest of this pattern already uses the
    // free function).
    auto resultTensorType = cast<RankedTensorType>(
        getTypeConverter()->convertType(op.getOut().getType()));

    auto inputTensor = adaptor.getInput();
    auto indicesTensor = adaptor.getIndices();
    auto indicesType = cast<RankedTensorType>(indicesTensor.getType());
    auto inputType = cast<RankedTensorType>(inputTensor.getType());
    // The static extent of the last indices dimension is the number of
    // leading input axes being gathered over.
    int numGatherAxes = indicesType.getShape().back();

    // Result dims: all but the last dim of `indices`, followed by the
    // non-gathered (trailing) dims of `input`.
    SmallVector<Value> resultDimSizes;
    for (int i = 0; i < indicesType.getRank() - 1; i++) {
      resultDimSizes.push_back(
          rewriter.createOrFold<tensor::DimOp>(loc, indicesTensor, i));
    }
    for (int i = numGatherAxes; i < inputType.getRank(); i++) {
      resultDimSizes.push_back(
          rewriter.createOrFold<tensor::DimOp>(loc, inputTensor, i));
    }

    // Cast avoids a signed/unsigned comparison warning (size() is unsigned,
    // getRank() returns int64_t).
    assert(static_cast<int64_t>(resultDimSizes.size()) ==
           resultTensorType.getRank());

    Value emptyTensor =
        rewriter.create<tensor::EmptyOp>(loc, getAsOpFoldResult(resultDimSizes),
                                         resultTensorType.getElementType());

    auto bodyBuilder = [&](OpBuilder &b, Location loc, ValueRange payloadArgs) {
      SmallVector<Value> valueIndices, gatherIndices;
      // The leading result-loop indices select a "row" of the indices tensor.
      for (int i = 0; i < indicesType.getRank() - 1; i++) {
        auto idx = b.create<linalg::IndexOp>(loc, b.getIndexType(),
                                             b.getI64IntegerAttr(i));
        gatherIndices.push_back(idx);
      }
      // Read the gather coordinate for each gathered input axis from the
      // last dimension of the indices tensor.
      for (int i = 0; i < numGatherAxes; i++) {
        SmallVector<Value> gi = gatherIndices;
        auto gidx = b.create<arith::ConstantOp>(loc, b.getIndexAttr(i));
        gi.push_back(gidx);
        assert(static_cast<int64_t>(gi.size()) == indicesType.getRank());
        auto idxExtract = b.create<tensor::ExtractOp>(
            loc, indicesType.getElementType(), indicesTensor, gi);
        // tensor.extract on an int tensor yields an integer; convert it to
        // `index` so it can be used to address the input tensor.
        auto idxCast =
            b.create<arith::IndexCastOp>(loc, b.getIndexType(), idxExtract);
        valueIndices.push_back(idxCast);
      }
      // The remaining result-loop indices address the non-gathered trailing
      // dims of the input directly.
      for (int i = indicesType.getRank() - 1; i < resultTensorType.getRank();
           i++) {
        auto idx = b.create<linalg::IndexOp>(loc, b.getIndexType(),
                                             b.getI64IntegerAttr(i));
        valueIndices.push_back(idx);
      }
      assert(static_cast<int64_t>(valueIndices.size()) == inputType.getRank());
      auto extract =
          b.create<tensor::ExtractOp>(loc, resultTensorType.getElementType(),
                                      inputTensor, valueIndices)
              .getResult();

      b.create<linalg::YieldOp>(loc, extract);
    };

    // No tensor inputs to the generic: the gather reads via tensor.extract
    // inside the body. The single identity map covers the output; every loop
    // is parallel because each result element is computed independently.
    SmallVector<Value> empty;
    SmallVector<AffineMap> indexingMaps;
    indexingMaps.push_back(
        rewriter.getMultiDimIdentityMap(resultTensorType.getRank()));
    SmallVector<utils::IteratorType> iteratorTypes(
        resultTensorType.getRank(), utils::IteratorType::parallel);

    auto generic = rewriter.create<linalg::GenericOp>(
        loc, resultTensorType, empty, emptyTensor, indexingMaps, iteratorTypes,
        bodyBuilder);

    rewriter.replaceOp(op, generic.getResult(0));

    return success();
  }
};


} // namespace

void mlir::TcpToLinalg::populateDataMovementPatternsAndLegality(
Expand All @@ -100,4 +196,6 @@ void mlir::TcpToLinalg::populateDataMovementPatternsAndLegality(

target.addIllegalOp<GatherOp>();
patterns.add<ConvertGatherOp>(typeConverter, context);
target.addIllegalOp<GatherNDOp>();
patterns.add<ConvertGatherNDOp>(typeConverter, context);
}
33 changes: 33 additions & 0 deletions lib/Conversion/TcpToTensor/TcpToTensor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,12 +47,44 @@ class SliceOpConverter : public OpConversionPattern<tcp::SliceOp> {
}
};

/**
* This lowers tcp.scatter_nd to tensor.scatter.
* However, tensor.scatter currently does not have anything that it can lower to, so it
* then fails to generate code
*/
/**
 * Lowers tcp.scatter_nd to tensor.scatter.
 *
 * NOTE(review): tensor.scatter currently has no further lowering available
 * downstream, so the resulting IR cannot yet be compiled to code.
 */
class ConvertScatterNDOp : public OpConversionPattern<ScatterNDOp> {
public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ScatterNDOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Value indicesIn = adaptor.getIndices();
    auto indicesType = cast<RankedTensorType>(indicesIn.getType());
    // tensor.scatter requires index-typed indices; tcp.scatter_nd carries an
    // integer tensor, so cast it element-wise.
    auto indices = rewriter
                       .create<arith::IndexCastOp>(
                           op.getLoc(),
                           RankedTensorType::get(indicesType.getShape(),
                                                 rewriter.getIndexType()),
                           indicesIn)
                       .getResult();
    // The last indices dim gives the number of scattered input axes; scatter
    // over the leading dims [0, numIndices). (Reuse indicesType instead of
    // re-deriving the ranked type from indicesIn.)
    int numIndices = indicesType.getShape().back();
    SmallVector<int64_t> scatterDims;
    for (int i = 0; i < numIndices; i++)
      scatterDims.push_back(i);
    // Final `true` is the `unique` flag — presumably asserting the indices do
    // not repeat; verify this holds for all tcp.scatter_nd producers.
    rewriter.replaceOpWithNewOp<tensor::ScatterOp>(
        op, op.getType(), adaptor.getValues(), adaptor.getInput(), indices,
        scatterDims, /*unique=*/true);
    return success();
  }
};

void populateTcpToTensorPatternsAndLegality(RewritePatternSet &patterns,
                                            ConversionTarget &target) {
  MLIRContext *context = patterns.getContext();

  // Mark the TCP ops handled by this pass as illegal and register their
  // converter patterns in one variadic call each.
  target.addIllegalOp<tcp::SliceOp, tcp::ScatterNDOp>();
  patterns.add<SliceOpConverter, ConvertScatterNDOp>(context);
}

class ConvertTcpToTensor : public ConvertTcpToTensorBase<ConvertTcpToTensor> {
Expand All @@ -61,6 +93,7 @@ class ConvertTcpToTensor : public ConvertTcpToTensorBase<ConvertTcpToTensor> {
MLIRContext *context = &getContext();
ConversionTarget target(*context);
target.addLegalDialect<mlir::tensor::TensorDialect>();
target.addLegalDialect<mlir::arith::ArithDialect>();

RewritePatternSet patterns(context);
populateTcpToTensorPatternsAndLegality(patterns, target);
Expand Down
Loading
Loading