Skip to content

Commit

Permalink
combine slice and concate to new Rope ConcatToRope
Browse files Browse the repository at this point in the history
Change-Id: Ib15b12fe97117b96c6fe7267c96c3f714aac6ec4
  • Loading branch information
xieminghe1 authored and Silence-Zhang-beijng committed Oct 25, 2024
1 parent 94b8ceb commit 85c9177
Showing 1 changed file with 102 additions and 1 deletion.
103 changes: 102 additions & 1 deletion lib/Dialect/Top/Canonicalize/Concat.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,107 @@ struct ConcatToDepth2SpacePattern : public OpRewriterPatternEx<ConcatOp> {
}
};

struct ConcatToRope : public OpRewriterPatternEx<ConcatOp> {
using OpRewriterPatternEx::OpRewriterPatternEx;

ConcatToRope(mlir::MLIRContext *context)
: OpRewriterPatternEx<ConcatOp>(context, "ConcatToRope") {}

LogicalResult matchAndRewriteImpl(ConcatOp op,
PatternRewriter &rewriter) const override {
if (op.getInputs().size() != 2) {
return failure();
}
int indx = 0;
for (int i = 0; i < 2; ++i) {
auto rope_op = dyn_cast<RopeOp>(op.getInputs()[i].getDefiningOp());
if(rope_op){
indx = i;
}else{
indx = 1 - i;
}
}
auto rope_op = dyn_cast<RopeOp>(op.getInputs()[indx].getDefiningOp());
auto slice0_op = dyn_cast<SliceOp>(op.getInputs()[1-indx].getDefiningOp());
if (!rope_op || !slice0_op) {
return failure();
}
auto slice1_op = dyn_cast<SliceOp>(rope_op.getInput1().getDefiningOp());
if(!slice1_op){
return failure();
}
Value in_value;
auto weight0 = rope_op.getInput2();
auto weight1 = rope_op.getInput3();
auto weight_shape = module::getShape(weight0);
auto W0 = dyn_cast<WeightOp>(weight0.getDefiningOp());
auto W1 = dyn_cast<WeightOp>(weight1.getDefiningOp());
if (!W0 || !W1) {
return failure();
}
auto left_weight = *(W0.read_as_float());
auto right_weight = *(W1.read_as_float());
std::vector<std::vector<std::vector<std::vector<float>>>> new_weight0(
weight_shape[0],
std::vector<std::vector<std::vector<float>>>(
weight_shape[1],
std::vector<std::vector<float>>(
weight_shape[2] + 1, std::vector<float>(weight_shape[3]))));
std::vector<std::vector<std::vector<std::vector<float>>>> new_weight1(weight_shape[0], std::vector<std::vector<std::vector<float>>>(weight_shape[1], std::vector<std::vector<float>>(weight_shape[2]+1, std::vector<float>(weight_shape[3]))));

std::vector<float> new_w0((weight_shape[2]+1)*weight_shape[3]);
std::vector<float> new_w1((weight_shape[2]+1)*weight_shape[3]);

for (int j = 0; j < weight_shape[3]; j++) {
new_weight0[0][0][0][j] = 0.0f;
new_weight1[0][0][0][j] = 1.0f;
}
int cnt = 0;
for (int i = 0; i < weight_shape[2]; i++) {
for (int j = 0; j < weight_shape[3]; j++) {
new_weight0[0][0][i + 1][j] = left_weight[cnt];
new_weight1[0][0][i + 1][j] = right_weight[cnt];
cnt += 1;
}
}

int count = 0;
for(int i=0;i<weight_shape[2]+1;i++){
for(int j=0;j<weight_shape[3];j++){
new_w0[count] = new_weight0[0][0][i][j];
new_w1[count] = new_weight1[0][0][i][j];
count += 1;
}
}


auto storage_type = module::getStorageType(op.getOutput());
if (!storage_type.isF32() && !storage_type.isF16()) {
return failure();
}

std::vector<int64_t> new_weight_shape = {weight_shape[0], weight_shape[1], weight_shape[2] + 1,
weight_shape[3]};

auto new_Weight0 = WeightOp::create_float(op, "weight0", new_w0,
new_weight_shape, storage_type);
auto new_Weight1 = WeightOp::create_float(op, "weight1", new_w1,
new_weight_shape, storage_type);

if (slice0_op.getInput().getDefiningOp() ==
slice1_op.getInput().getDefiningOp()
) {
in_value = slice0_op.getInput();
}
else{
return failure();
}
std::vector<NamedAttribute> attrs;
rewriter.replaceOpWithNewOp<RopeOp>(op, op.getResult().getType(),
ValueRange{in_value, new_Weight0, new_Weight1}, attrs);
return success();
}
};
struct ConcatToDepth2SpacePattern2 : public OpRewriterPatternEx<ConcatOp> {
using OpRewriterPatternEx::OpRewriterPatternEx;

Expand Down Expand Up @@ -497,7 +598,7 @@ struct RemoveInvaidConcatSlice : public OpRewriterPatternEx<ConcatOp> {
void ConcatOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.insert<ConvertLoadWeightConcatToLoadWeightPattern,
ConcatToDepth2SpacePattern, ConcatToDepth2SpacePattern2,
ConcatToDepth2SpacePattern, ConcatToRope,ConcatToDepth2SpacePattern2,
MergeSliceConcatPattern, RemoveInvaidConcatSlice,
RemoveInvaidShapeConcatInput>(context);
}

0 comments on commit 85c9177

Please sign in to comment.