Skip to content

Commit

Permalink
Prevent partitioning from moving ops across side-effecting ops. (#7639)
Browse files Browse the repository at this point in the history
This is conservative and uses any side-effecting non-streamable op
as a barrier. We could use memory effects to do better things but
this at least should keep us correct. It's assumed that asserts should
be hoisted higher up in the IR during canonicalization.
  • Loading branch information
benvanik committed Nov 12, 2021
1 parent c4e017c commit fabc494
Show file tree
Hide file tree
Showing 4 changed files with 61 additions and 9 deletions.
2 changes: 1 addition & 1 deletion iree/compiler/Dialect/Stream/Analysis/Partitioning.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ void dumpPartition(Partition &partition, AsmState &state) {
out.printAsOperand(llvm::dbgs(), state);
});
llvm::dbgs() << "\n OPS:\n";
for (auto *op : partition.ops) {
for (auto *op : llvm::reverse(partition.ops)) {
llvm::dbgs() << " ";
op->print(llvm::dbgs(), state);
llvm::dbgs() << "\n";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ PartitionSet partitionStreamableOpsReference(
SetVector<Operation *> ops;
};
SmallVector<std::unique_ptr<PartitionBuilder>> builders;
llvm::BitVector usableBuilders;

struct OpInfo {
// Which partitions the op is contained within.
Expand All @@ -49,6 +50,18 @@ PartitionSet partitionStreamableOpsReference(
if (op.hasTrait<OpTrait::ConstantLike>()) {
LLVM_DEBUG(llvm::dbgs() << "(ignoring constant)\n");
continue;
} else if (!isa<IREE::Stream::StreamableOpInterface>(op)) {
// Not a streamable op. If it has side-effects then we force a hazard on
// all builders so that we don't move ops across it.
if (!mlir::wouldOpBeTriviallyDead(&op)) {
LLVM_DEBUG({
llvm::dbgs() << "Side-effecting op forcing flush and freeze:\n";
op.dump();
});
usableBuilders.reset();
continue;
}
// Even though not a streamable op we still want to track it below.
}

// Initialize op info for this op - whether streamable or not. We track
Expand Down Expand Up @@ -93,6 +106,7 @@ PartitionSet partitionStreamableOpsReference(
llvm::BitVector candidates(builders.size(), /*t=*/true);
candidates ^= opInfo.hazards;
candidates |= consumers;
candidates &= usableBuilders;

// Prune candidates that do not have a compatible affinity.
for (auto ordinal : candidates.set_bits()) {
Expand Down Expand Up @@ -161,6 +175,7 @@ PartitionSet partitionStreamableOpsReference(
LLVM_DEBUG(llvm::dbgs()
<< "Created partition " << builder->ordinal << "\n");
builders.push_back(std::move(builder));
usableBuilders.resize(builders.size(), /*t=*/true);
}

// Emit partitions in forward order (as they are topologically sorted in
Expand Down
14 changes: 6 additions & 8 deletions iree/compiler/Dialect/Stream/Transforms/ScheduleExecution.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,16 +51,14 @@ struct ExecutePartitionBuilder {
auto fusedLoc = FusedLoc::get(context, locs);

// Find the insertion point in the parent block.
// This is at the last op defining an input as all inputs must be available.
// This is at the last op in the partition.
Operation *insertionPt = nullptr;
for (auto in : partition->ins) {
auto *definingOp = in.getDefiningOp();
if (!definingOp) continue;
if (definingOp->getBlock() != parentBlock) continue;
for (auto *op : partition->ops) {
if (op->getBlock() != parentBlock) continue;
if (!insertionPt) {
insertionPt = definingOp; // first defining op
} else if (insertionPt->isBeforeInBlock(definingOp)) {
insertionPt = definingOp; // moving insertion point down
insertionPt = op; // first defining op
} else if (insertionPt->isBeforeInBlock(op)) {
insertionPt = op; // moving insertion point down
}
}
OpBuilder parentBuilder(context);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -102,3 +102,42 @@ func @deviceHostDevice() -> !stream.resource<transient> {
// CHECK: return %[[READY_H2D]]
return %5 : !stream.resource<transient>
}

// -----

// Tests that partitioning does not hoist ops across asserts.

// CHECK-LABEL: @dontHoistPastAsserts
func @dontHoistPastAsserts(%arg0: !stream.resource<external>, %arg1: !stream.resource<external>) -> !stream.resource<external> {
%c1 = arith.constant 1 : index
%c20 = arith.constant 20 : index
%c80 = arith.constant 80 : index
%c1280 = arith.constant 1280 : index
%cst = arith.constant 0x7F800000 : f32
%cond_a = arith.constant 0 : i1
%cond_b = arith.constant 0 : i1

// CHECK: stream.async.execute
// CHECK-NEXT: stream.async.splat
%2 = stream.async.splat %cst : f32 -> !stream.resource<transient>{%c1280}
// CHECK-NEXT: stream.async.dispatch @ex::@dispatch_0
%3 = stream.async.dispatch @ex::@dispatch_0[%c1, %c1, %c1](%2, %arg1) : (!stream.resource<transient>{%c1280}, !stream.resource<external>{%c80}) -> %2{%c1280}

// CHECK: "assert A"
assert %cond_a, "assert A"

// CHECK: stream.async.execute
// CHECK-NEXT: stream.async.splat
%4 = stream.async.splat %cst : f32 -> !stream.resource<transient>{%c20}
// CHECK-NEXT: stream.async.dispatch @ex::@dispatch_1
%5 = stream.async.dispatch @ex::@dispatch_1[%c1, %c1, %c1](%arg0, %4) : (!stream.resource<external>{%c20}, !stream.resource<transient>{%c20}) -> %4{%c20}

// CHECK: "assert B"
assert %cond_b, "assert B"

// CHECK: stream.async.execute
// CHECK-NEXT: stream.async.dispatch @ex::@dispatch_2
%6 = stream.async.dispatch @ex::@dispatch_2[%c1, %c1, %c1](%3, %5) : (!stream.resource<transient>{%c1280}, !stream.resource<transient>{%c20}) -> !stream.resource<external>{%c20}

return %6 : !stream.resource<external>
}

0 comments on commit fabc494

Please sign in to comment.