diff --git a/iree/compiler/Dialect/Stream/Analysis/Partitioning.cpp b/iree/compiler/Dialect/Stream/Analysis/Partitioning.cpp index 104e6512c39a..466b6662a456 100644 --- a/iree/compiler/Dialect/Stream/Analysis/Partitioning.cpp +++ b/iree/compiler/Dialect/Stream/Analysis/Partitioning.cpp @@ -30,7 +30,7 @@ void dumpPartition(Partition &partition, AsmState &state) { out.printAsOperand(llvm::dbgs(), state); }); llvm::dbgs() << "\n OPS:\n"; - for (auto *op : partition.ops) { + for (auto *op : llvm::reverse(partition.ops)) { llvm::dbgs() << " "; op->print(llvm::dbgs(), state); llvm::dbgs() << "\n"; diff --git a/iree/compiler/Dialect/Stream/Analysis/Partitioning/ReferencePartitioning.cpp b/iree/compiler/Dialect/Stream/Analysis/Partitioning/ReferencePartitioning.cpp index 1d157b5d65bf..9c478a03e2bd 100644 --- a/iree/compiler/Dialect/Stream/Analysis/Partitioning/ReferencePartitioning.cpp +++ b/iree/compiler/Dialect/Stream/Analysis/Partitioning/ReferencePartitioning.cpp @@ -34,6 +34,7 @@ PartitionSet partitionStreamableOpsReference( SetVector ops; }; SmallVector> builders; + llvm::BitVector usableBuilders; struct OpInfo { // Which partitions the op is contained within. @@ -49,6 +50,18 @@ PartitionSet partitionStreamableOpsReference( if (op.hasTrait()) { LLVM_DEBUG(llvm::dbgs() << "(ignoring constant)\n"); continue; + } else if (!isa(op)) { + // Not a streamable op. If it has side-effects then we force a hazard on + // all builders so that we don't move ops across it. + if (!mlir::wouldOpBeTriviallyDead(&op)) { + LLVM_DEBUG({ + llvm::dbgs() << "Side-effecting op forcing flush and freeze:\n"; + op.dump(); + }); + usableBuilders.reset(); + continue; + } + // Even though not a streamable op we still want to track it below. } // Initialize op info for this op - whether streamable or not. 
We track @@ -93,6 +106,7 @@ PartitionSet partitionStreamableOpsReference( llvm::BitVector candidates(builders.size(), /*t=*/true); candidates ^= opInfo.hazards; candidates |= consumers; + candidates &= usableBuilders; // Prune candidates that do not have a compatible affinity. for (auto ordinal : candidates.set_bits()) { @@ -161,6 +175,7 @@ PartitionSet partitionStreamableOpsReference( LLVM_DEBUG(llvm::dbgs() << "Created partition " << builder->ordinal << "\n"); builders.push_back(std::move(builder)); + usableBuilders.resize(builders.size(), /*t=*/true); } // Emit partitions in forward order (as they are topologically sorted in diff --git a/iree/compiler/Dialect/Stream/Transforms/ScheduleExecution.cpp b/iree/compiler/Dialect/Stream/Transforms/ScheduleExecution.cpp index 8c5c4e5e74a0..b75e7bcffafb 100644 --- a/iree/compiler/Dialect/Stream/Transforms/ScheduleExecution.cpp +++ b/iree/compiler/Dialect/Stream/Transforms/ScheduleExecution.cpp @@ -51,16 +51,14 @@ struct ExecutePartitionBuilder { auto fusedLoc = FusedLoc::get(context, locs); // Find the insertion point in the parent block. - // This is at the last op defining an input as all inputs must be available. + // This is at the last op in the partition. 
Operation *insertionPt = nullptr; - for (auto in : partition->ins) { - auto *definingOp = in.getDefiningOp(); - if (!definingOp) continue; - if (definingOp->getBlock() != parentBlock) continue; + for (auto *op : partition->ops) { + if (op->getBlock() != parentBlock) continue; if (!insertionPt) { - insertionPt = definingOp; // first defining op - } else if (insertionPt->isBeforeInBlock(definingOp)) { - insertionPt = definingOp; // moving insertion point down + insertionPt = op; // first partition op in the parent block + } else if (insertionPt->isBeforeInBlock(op)) { + insertionPt = op; // moving insertion point down } } OpBuilder parentBuilder(context); diff --git a/iree/compiler/Dialect/Stream/Transforms/test/schedule_execution.mlir b/iree/compiler/Dialect/Stream/Transforms/test/schedule_execution.mlir index 4f1b70b26711..f28ecf618476 100644 --- a/iree/compiler/Dialect/Stream/Transforms/test/schedule_execution.mlir +++ b/iree/compiler/Dialect/Stream/Transforms/test/schedule_execution.mlir @@ -102,3 +102,42 @@ func @deviceHostDevice() -> !stream.resource { // CHECK: return %[[READY_H2D]] return %5 : !stream.resource } + +// ----- + +// Tests that partitioning does not hoist ops across asserts. 
+ +// CHECK-LABEL: @dontHoistPastAsserts +func @dontHoistPastAsserts(%arg0: !stream.resource, %arg1: !stream.resource) -> !stream.resource { + %c1 = arith.constant 1 : index + %c20 = arith.constant 20 : index + %c80 = arith.constant 80 : index + %c1280 = arith.constant 1280 : index + %cst = arith.constant 0x7F800000 : f32 + %cond_a = arith.constant 0 : i1 + %cond_b = arith.constant 0 : i1 + + // CHECK: stream.async.execute + // CHECK-NEXT: stream.async.splat + %2 = stream.async.splat %cst : f32 -> !stream.resource{%c1280} + // CHECK-NEXT: stream.async.dispatch @ex::@dispatch_0 + %3 = stream.async.dispatch @ex::@dispatch_0[%c1, %c1, %c1](%2, %arg1) : (!stream.resource{%c1280}, !stream.resource{%c80}) -> %2{%c1280} + + // CHECK: "assert A" + assert %cond_a, "assert A" + + // CHECK: stream.async.execute + // CHECK-NEXT: stream.async.splat + %4 = stream.async.splat %cst : f32 -> !stream.resource{%c20} + // CHECK-NEXT: stream.async.dispatch @ex::@dispatch_1 + %5 = stream.async.dispatch @ex::@dispatch_1[%c1, %c1, %c1](%arg0, %4) : (!stream.resource{%c20}, !stream.resource{%c20}) -> %4{%c20} + + // CHECK: "assert B" + assert %cond_b, "assert B" + + // CHECK: stream.async.execute + // CHECK-NEXT: stream.async.dispatch @ex::@dispatch_2 + %6 = stream.async.dispatch @ex::@dispatch_2[%c1, %c1, %c1](%3, %5) : (!stream.resource{%c1280}, !stream.resource{%c20}) -> !stream.resource{%c20} + + return %6 : !stream.resource +}