From b5851565e0bf0197ce50f530a284d3ea2ceb66a3 Mon Sep 17 00:00:00 2001 From: Ben Vanik Date: Mon, 8 Nov 2021 08:36:28 -0800 Subject: [PATCH] Adding -iree-stream-schedule-execution + -concurrency passes. (#7549) The passes themselves are rather simple and call into a partitioning routine that performs the real work with the intent being that we can have many and specify which one to use based on scoped attributes in the IR (kind of like lowering configs in codegen). Today there's just a reference implementation that does a single level of concurrency. The hope is that someone who actually knows how to write a good partitioning algorithm can contribute something better, but it's at least no worse than what we have today and better than simple ML systems that have no concurrency. Though the passes are similar they operate at different scopes and will have different partitioning algorithms. I thought about trying to unify them however keeping them separate allows us to do things like use a more complex execution partitioning pass while using the same generic concurrency scheduling etc - including disabling the concurrency scheduling entirely for debugging or environments where there may be no benefits to such scheduling (single core execution, etc). It's easy enough to reason about how they could be unified that I wanted to err on the side of flexibility until we have an owner and at least one or two more algorithms we can use to feel out the shape of things. A benefit of the independent execution and concurrency partitioning is that debugging either is much simpler (and there's pretty good `-debug` output). Since the concurrency scheduling operates only within the scheduled execution regions there's no need to worry about host/device interactions or the parent op CFG. --- iree/compiler/Dialect/Stream/Analysis/BUILD | 3 + .../Dialect/Stream/Analysis/CMakeLists.txt | 3 + .../Dialect/Stream/Analysis/Partitioning.cpp | 183 +++++++++ .../Dialect/Stream/Analysis/Partitioning.h | 131 +++++++ .../Partitioning/ReferencePartitioning.cpp | 350 ++++++++++++++++++ iree/compiler/Dialect/Stream/Transforms/BUILD | 2 + .../Dialect/Stream/Transforms/CMakeLists.txt | 2 + .../Dialect/Stream/Transforms/Passes.cpp | 16 + .../Dialect/Stream/Transforms/Passes.h | 7 + .../Dialect/Stream/Transforms/Passes.td | 20 + .../Stream/Transforms/ScheduleConcurrency.cpp | 281 ++++++++++++++ .../Stream/Transforms/ScheduleExecution.cpp | 347 +++++++++++++++++ .../Dialect/Stream/Transforms/test/BUILD | 2 + .../Stream/Transforms/test/CMakeLists.txt | 2 + .../Transforms/test/schedule_concurrency.mlir | 47 +++ .../Transforms/test/schedule_execution.mlir | 104 ++++++ 16 files changed, 1500 insertions(+) create mode 100644 iree/compiler/Dialect/Stream/Analysis/Partitioning.cpp create mode 100644 iree/compiler/Dialect/Stream/Analysis/Partitioning.h create mode 100644 iree/compiler/Dialect/Stream/Analysis/Partitioning/ReferencePartitioning.cpp create mode 100644 iree/compiler/Dialect/Stream/Transforms/ScheduleConcurrency.cpp create mode 100644 iree/compiler/Dialect/Stream/Transforms/ScheduleExecution.cpp create mode 100644 iree/compiler/Dialect/Stream/Transforms/test/schedule_concurrency.mlir create mode 100644 iree/compiler/Dialect/Stream/Transforms/test/schedule_execution.mlir diff --git a/iree/compiler/Dialect/Stream/Analysis/BUILD b/iree/compiler/Dialect/Stream/Analysis/BUILD index ad307dc93d8c..b8e7ce80a325 100644 --- a/iree/compiler/Dialect/Stream/Analysis/BUILD +++ b/iree/compiler/Dialect/Stream/Analysis/BUILD @@ -13,9 +13,12 @@ package( cc_library( name = "Analysis", srcs = [ + "Partitioning.cpp", + "Partitioning/ReferencePartitioning.cpp", "ResourceUsage.cpp", ], hdrs = [ + "Partitioning.h", "ResourceUsage.h", ], deps = [ diff --git a/iree/compiler/Dialect/Stream/Analysis/CMakeLists.txt b/iree/compiler/Dialect/Stream/Analysis/CMakeLists.txt index 2fbacefd6345..7cff2f21f27f 100644 --- a/iree/compiler/Dialect/Stream/Analysis/CMakeLists.txt +++ b/iree/compiler/Dialect/Stream/Analysis/CMakeLists.txt @@ -14,8 +14,11 @@ iree_cc_library( NAME Analysis HDRS + "Partitioning.h" "ResourceUsage.h" SRCS + "Partitioning.cpp" + "Partitioning/ReferencePartitioning.cpp" "ResourceUsage.cpp" DEPS LLVMSupport diff --git a/iree/compiler/Dialect/Stream/Analysis/Partitioning.cpp b/iree/compiler/Dialect/Stream/Analysis/Partitioning.cpp new file mode 100644 index 000000000000..104e6512c39a --- /dev/null +++ b/iree/compiler/Dialect/Stream/Analysis/Partitioning.cpp @@ -0,0 +1,183 @@ +// Copyright 2021 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include "iree/compiler/Dialect/Stream/Analysis/Partitioning.h" + +#include "llvm/ADT/STLExtras.h" +#include "llvm/Support/Debug.h" +#include "mlir/IR/AsmState.h" +#include "mlir/IR/PatternMatch.h" + +#define DEBUG_TYPE "iree-stream-partitioning" + +namespace mlir { +namespace iree_compiler { +namespace IREE { +namespace Stream { + +#ifndef NDEBUG + +void dumpPartition(Partition &partition, AsmState &state) { + llvm::dbgs() << " INS:\n "; + llvm::interleaveComma(partition.ins, llvm::dbgs(), [&](Value in) { + in.printAsOperand(llvm::dbgs(), state); + }); + llvm::dbgs() << "\n OUTS:\n "; + llvm::interleaveComma(partition.outs, llvm::dbgs(), [&](Value out) { + out.printAsOperand(llvm::dbgs(), state); + }); + llvm::dbgs() << "\n OPS:\n"; + for (auto *op : partition.ops) { + llvm::dbgs() << " "; + op->print(llvm::dbgs(), state); + llvm::dbgs() << "\n"; + } +} + +void Partition::dump(Operation *parentOp) { + AsmState state(parentOp); + dumpPartition(*this, state); +} + +void PartitionSet::dump(Operation *parentOp) { + AsmState state(parentOp); + for (auto partition : llvm::enumerate(partitions)) { + llvm::dbgs() << "PARTITION[" << partition.index() << "]:\n"; + dumpPartition(partition.value(), state); + } +} + +#else +void Partition::dump(Operation *parentOp) {} +void PartitionSet::dump(Operation *parentOp) {} +#endif // !NDEBUG + +LogicalResult Partition::verify(Location loc) { + // Ensure values are defined either by other ops in the partition or are + // declared as inputs. + SetVector defValues; + for (auto *op : ops) { + for (auto result : op->getResults()) { + defValues.insert(result); + } + } + for (auto *op : ops) { + for (auto operand : op->getOperands()) { + if (!ins.contains(operand) && !defValues.contains(operand)) { + return mlir::emitError(loc) + << "operand not declared in partition inputs or by an op within " + "the partition"; + } + } + } + + // Ensure all outputs come from ops in the partition (or are pass-through + // operands, though those are silly). + for (auto out : outs) { + if (!ins.contains(out) && !defValues.contains(out)) { + return mlir::emitError(loc) << "output not defined by an op within the " + "partition (or captured)"; + } + } + + return success(); +} + +LogicalResult PartitionSet::verify(Location loc) { + // Verify each partition is consistent. + for (auto &partition : partitions) { + if (failed(partition.verify(loc))) return failure(); + } + + // Ensure no partitions duplicate escaping values as we need a single def to + // remap the value in the parent block. + SetVector outs; + for (auto &partition : partitions) { + for (auto out : partition.outs) { + if (outs.contains(out)) { + return mlir::emitError(loc) + << "duplicate value found in partition set outputs"; + } + outs.insert(out); + } + } + + // Ensure a correct topological order of partitions. This only checks the + // order of the partitions and not any ops that aren't covered. We do this + // by walking backwards and checking that no partition captures values + // escaping any partitions after it. + SetVector declaredBelow; + for (auto &partition : llvm::reverse(partitions)) { + for (auto in : partition.ins) { + if (declaredBelow.contains(in)) { + return mlir::emitError(loc) << "partition set out of order; value " + "captured declared as escaping below"; + } + } + for (auto out : partition.outs) { + declaredBelow.insert(out); + } + } + + return success(); +} + +void PartitionSet::topologicalSort() { + if (partitions.empty()) return; + + SetVector unsortedSet; + DenseMap> consumers; + for (auto &partition : partitions) { + unsortedSet.insert(&partition); + for (auto in : partition.ins) { + consumers[in].push_back(&partition); + } + } + + struct DFSState { + SmallVector topologicalCounts; + DenseSet seen; + } state; + std::function postorderWalk; + postorderWalk = [&](Partition *current) { + for (auto out : current->outs) { + for (auto *consumer : consumers[out]) { + postorderWalk(consumer); + } + } + auto it = state.seen.insert(current); + if (/*inserted=*/it.second) { + if (unsortedSet.contains(current)) { + state.topologicalCounts.push_back(current); + } + } + }; + for (auto *partition : unsortedSet) postorderWalk(partition); + + SmallVector sortedSet; + sortedSet.reserve(partitions.size()); + for (auto *partition : llvm::reverse(state.topologicalCounts)) { + sortedSet.push_back(std::move(*partition)); + } + partitions = std::move(sortedSet); +} + +PartitionSet partitionStreamableOps(IREE::Stream::PartitioningConfigAttr config, + Block *block) { + // Only one algorithm today. + return partitionStreamableOpsReference(config, block); +} + +PartitionSet partitionRegionConcurrency( + IREE::Stream::PartitioningConfigAttr config, Block *block) { + // Only one algorithm today. + return partitionRegionConcurrencyReference(config, block); +} + +} // namespace Stream +} // namespace IREE +} // namespace iree_compiler +} // namespace mlir diff --git a/iree/compiler/Dialect/Stream/Analysis/Partitioning.h b/iree/compiler/Dialect/Stream/Analysis/Partitioning.h new file mode 100644 index 000000000000..49f7608a4f09 --- /dev/null +++ b/iree/compiler/Dialect/Stream/Analysis/Partitioning.h @@ -0,0 +1,131 @@ +// Copyright 2021 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#ifndef IREE_COMPILER_DIALECT_STREAM_ANALYSIS_PARTITIONING_H_ +#define IREE_COMPILER_DIALECT_STREAM_ANALYSIS_PARTITIONING_H_ + +#include "iree/compiler/Dialect/Stream/IR/StreamTypes.h" +#include "mlir/IR/Operation.h" +#include "mlir/Support/LLVM.h" + +namespace mlir { +namespace iree_compiler { +namespace IREE { +namespace Stream { + +//===----------------------------------------------------------------------===// +// Data structures +//===----------------------------------------------------------------------===// + +// A single slice of ops. +struct Partition { + // SSA values defined outside of the partition. + // All values not defined by ops in the partition must be declared. + // Multiple partitions may capture the same value. + SetVector ins; + // SSA values defined by the partition with uses outside. + // All values used by ops outside of the partition must be declared. + // Only one partition may produce a new value. + SetVector outs; + // All ops covered by the partition. May contain ops that exist in other + // partitions in cases where the op is to be duplicated. Not all ops are + // streamable (such as constants and arithmetic). + SetVector ops; + + void dump(Operation *parentOp); + + // Verifies that the partition meets the required conditions. + LogicalResult verify(Location loc); +}; + +// A set of all partitions. +struct PartitionSet { + // All partitions in an undefined topological order. + SmallVector partitions; + + // Total number of partitions in the set. + size_t size() const { return partitions.size(); } + // Returns true if the set is empty (no streamable ops). + bool empty() const { return partitions.empty(); } + + void dump(Operation *parentOp); + + // Verifies that the partition set meets the required conditions. + LogicalResult verify(Location loc); + + // Sorts all partitions in a topological order. + void topologicalSort(); +}; + +//===----------------------------------------------------------------------===// +// Stream partitioning algorithms +//===----------------------------------------------------------------------===// +// +// When these algorithms run all streamable operations have had an affinity +// assigned and are lowered out of tensor form. Some resources may have +// lifetimes associated but most will remain unassigned (`!stream.resource<*>`) +// until after partitioning. Note that there may already exist partitioned ops +// in stream.execute regions already. +// +// The intent is that we can use the information we have about each operation, +// the resources moving between them, and where they should execute to better +// partition the DAG. This could optimize for reducing memory transfer between +// devices, reducing latency by minimizing cuts, maximizing concurrency by +// separating non-interfering subgraphs, etc. +// +// This is a well-researched area and there are many algorithms to choose from. +// We'll mostly want to focus on ones that are able to handle multiple critera +// (like memory consumption, compute utilization, available capacity, etc). +// +// See for example: +// dagP: https://github.com/GT-TDAlab/dagP +// Multilevel Algorithms for Acyclic Partitioning of Directed Acyclic Graphs +// https://hal.inria.fr/hal-02306566/document +// METIS: https://github.com/KarypisLab/METIS +// A Fast and High Quality Multilevel Scheme for Partitioning Ireegular +// Graphs +// http://glaros.dtc.umn.edu/gkhome/metis/metis/publications +// SCOTCH: https://www.labri.fr/perso/pelegrin/scotch/ +// Contributions to Parallel Multilevel Graph Partitioning +// https://www.labri.fr/perso/pelegrin/papers/hdr.pdf +// Zoltan: https://cs.sandia.gov/Zoltan/ +// https://cs.sandia.gov/Zoltan/Zoltan_pubs.html +// https://cs.sandia.gov/Zoltan/papers/zoltan_tutorial_dagstuhl09.pdf +// +// And some good papers/overviews: +// - Edge Partitioning of Large Graphs +// https://tel.archives-ouvertes.fr/tel-01956979/document +// + +// Partitions the ops in |block| such that all streamable ops are in one or more +// partitions (with >1 implying duplication). Partitions may contain +// non-streamable ops if it is safe to do so (such as std arithmetic). Not all +// ops in the block will be covered by a partition. +PartitionSet partitionStreamableOps(IREE::Stream::PartitioningConfigAttr config, + Block *block); +PartitionSet partitionRegionConcurrency( + IREE::Stream::PartitioningConfigAttr config, Block *block); + +//===----------------------------------------------------------------------===// +// Reference partitioning +//===----------------------------------------------------------------------===// + +// Naive clustering based solely on correctness with no cost model or weighting. +// Produces the largest possible streams for any given block. Unsatisfactory. +PartitionSet partitionStreamableOpsReference( + IREE::Stream::PartitioningConfigAttr config, Block *block); + +// Similarly poor algorithm to partitionStreamableOpsReference but for use +// within partitioned streams to produce waves of concurrently executable work. +PartitionSet partitionRegionConcurrencyReference( + IREE::Stream::PartitioningConfigAttr config, Block *block); + +} // namespace Stream +} // namespace IREE +} // namespace iree_compiler +} // namespace mlir + +#endif // IREE_COMPILER_DIALECT_STREAM_ANALYSIS_PARTITIONING_H_ diff --git a/iree/compiler/Dialect/Stream/Analysis/Partitioning/ReferencePartitioning.cpp b/iree/compiler/Dialect/Stream/Analysis/Partitioning/ReferencePartitioning.cpp new file mode 100644 index 000000000000..1d157b5d65bf --- /dev/null +++ b/iree/compiler/Dialect/Stream/Analysis/Partitioning/ReferencePartitioning.cpp @@ -0,0 +1,350 @@ +// Copyright 2021 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include "iree/compiler/Dialect/Stream/Analysis/Partitioning.h" +#include "llvm/ADT/BitVector.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallPtrSet.h" +#include "llvm/Support/Debug.h" +#include "mlir/IR/PatternMatch.h" + +#define DEBUG_TYPE "iree-stream-partitioning" + +namespace mlir { +namespace iree_compiler { +namespace IREE { +namespace Stream { + +// This is terrible. See Stream/Analysis/Partition.h for a description of what +// a real implementation would do. We want cost modeling for tie breakers when +// an op could be in multiple partitions, cloning for ops that are not worth +// spanning partitions (like splats), etc. +PartitionSet partitionStreamableOpsReference( + IREE::Stream::PartitioningConfigAttr config, Block *block) { + PartitionSet partitionSet; + + struct PartitionBuilder { + unsigned ordinal; + // Affinity of the partition. + IREE::Stream::AffinityAttr affinity; + // Ops present in the partition; ops may be present in multiple partitions. + SetVector ops; + }; + SmallVector> builders; + + struct OpInfo { + // Which partitions the op is contained within. + llvm::BitVector membership; + // Which partitions transitively depend on this operation. + llvm::BitVector hazards; + }; + DenseMap opInfos; + + for (auto &op : llvm::reverse(*block)) { + // Skip constants; they just add noise (and since they are heavily CSE'd + // they have lots of users to test). + if (op.hasTrait()) { + LLVM_DEBUG(llvm::dbgs() << "(ignoring constant)\n"); + continue; + } + + // Initialize op info for this op - whether streamable or not. We track + // transitive hazards on each op. Note that thanks to the ordering of ops + // in SSA form (_reversed here!_) we know that once we visit this op no + // partition created after it can ever depend on it if it doesn't here. This + // lets us keep the bitvectors small. + auto &opInfo = opInfos[&op]; + opInfo.hazards.reserve(builders.size() + 1); + opInfo.hazards.resize(builders.size(), /*t=*/false); + + IREE::Stream::AffinityAttr affinityAttr; + if (auto affinityOp = dyn_cast(op)) { + affinityAttr = affinityOp.getAffinity(); + } + + LLVM_DEBUG({ + llvm::dbgs() << "====\nPartitioning op:\n"; + op.dump(); + }); + + // Set bits for each partition this op may be able to be placed into. + // We prune the set based on whether the users are part of a transitive + // dependency chain down the use-def chain to a partition. + llvm::BitVector consumers(builders.size(), /*t=*/false); + for (auto user : op.getUsers()) { + auto &userInfo = opInfos[user]; + LLVM_DEBUG({ + llvm::dbgs() << "Testing user:\n"; + user->dump(); + for (auto membershipOrdinal : userInfo.membership.set_bits()) { + llvm::dbgs() << " member of partition " << membershipOrdinal << "\n"; + } + for (auto hazardOrdinal : userInfo.hazards.set_bits()) { + llvm::dbgs() << " hazard w/ partition " << hazardOrdinal << "\n"; + } + }); + consumers |= userInfo.membership; + opInfo.hazards |= userInfo.membership; + opInfo.hazards |= userInfo.hazards; + } + llvm::BitVector candidates(builders.size(), /*t=*/true); + candidates ^= opInfo.hazards; + candidates |= consumers; + + // Prune candidates that do not have a compatible affinity. + for (auto ordinal : candidates.set_bits()) { + if (!IREE::Stream::AffinityAttr::areCompatible( + affinityAttr, builders[ordinal]->affinity)) { + LLVM_DEBUG(llvm::dbgs() + << "Candidate partition " << ordinal << " incompatible\n"); + candidates.reset(ordinal); + } + } + + // If this op is not streamable then bail here; we've still setup the hazard + // map for following iteration. + auto streamableOp = dyn_cast(op); + if (!streamableOp) { + LLVM_DEBUG(llvm::dbgs() << "Not streamable (skip)\n"); + continue; + } + + // First see which partitions are consuming this that we can also safely + // move in to. + consumers &= candidates; + + opInfo.membership.reserve(builders.size() + 1); + opInfo.membership.resize(builders.size(), /*t=*/false); + + // If we have one or more consumers we should go into those first. + if (consumers.any()) { + // If we are a clonable op (like splat) clone us into every partition. + // Otherwise we just pick the first we find (probably a bad heuristic). + bool shouldClone = streamableOp.preferCloneToConsumers(); + for (auto consumerOrdinal : consumers.set_bits()) { + LLVM_DEBUG(llvm::dbgs() << "Cloning into consumer partition " + << consumerOrdinal << "\n"); + builders[consumerOrdinal]->ops.insert(&op); + opInfo.membership.set(consumerOrdinal); + opInfo.hazards.reset(consumerOrdinal); + if (!shouldClone) break; + } + LLVM_DEBUG(llvm::dbgs() << "Handled streamable (continue)\n"); + continue; + } + + // No consumers - if there's any candidate then we'll go into that. + int firstCandidateOrdinal = candidates.find_first(); + if (firstCandidateOrdinal != -1) { + LLVM_DEBUG(llvm::dbgs() << "Moving to first candidate partition " + << firstCandidateOrdinal << " (continue)\n"); + builders[firstCandidateOrdinal]->ops.insert(&op); + opInfo.membership.set(firstCandidateOrdinal); + opInfo.hazards.reset(firstCandidateOrdinal); + continue; + } + + // Mark the op as having hazards against all other partitions. + if (!builders.empty()) { + opInfo.hazards.set(0, builders.size() - 1); + } + + // Create a new partition just for this op. + opInfo.membership.resize(opInfo.membership.size() + 1, /*t=*/true); + auto builder = std::make_unique(); + builder->ordinal = builders.size(); + builder->affinity = affinityAttr; + builder->ops.insert(&op); + LLVM_DEBUG(llvm::dbgs() + << "Created partition " << builder->ordinal << "\n"); + builders.push_back(std::move(builder)); + } + + // Emit partitions in forward order (as they are topologically sorted in + // reverse order from our bottom-up walk). + for (auto &builder : llvm::reverse(builders)) { + Partition partition; + + SetVector consumedValues; + SetVector producedValues; + SetVector escapingValues; + for (auto *op : llvm::reverse(builder->ops)) { + for (auto operand : op->getOperands()) { + consumedValues.insert(operand); + } + for (auto result : op->getResults()) { + producedValues.insert(result); + // TODO(benvanik): optimize this - creates n^2/nlogn behavior. + for (auto user : result.getUsers()) { + if (!builder->ops.contains(user)) { + escapingValues.insert(result); + } + } + } + } + consumedValues.set_subtract(producedValues); + partition.ins = consumedValues; + partition.outs = escapingValues; + + partition.ops = std::move(builder->ops); + partitionSet.partitions.push_back(std::move(partition)); + } + + LLVM_DEBUG(partitionSet.dump(block->getParentOp())); + + return partitionSet; +} + +// This looks to extract a single level of concurrency; we should be recursively +// dividing the block to identify both serial and concurrent regions. +PartitionSet partitionRegionConcurrencyReference( + IREE::Stream::PartitioningConfigAttr config, Block *block) { + PartitionSet waveSet; + + auto favor = config ? config.getFavor().getValue() + : IREE::Stream::Favor::MinPeakMemory; + if (favor == IREE::Stream::Favor::Debug) { + // Disable partitioning when favoring debugability. + return waveSet; + } + + struct PartitionBuilder { + unsigned ordinal; + // Ops present in the wave; ops may be present in multiple waves. + SetVector ops; + }; + SmallVector> builders; + + struct OpInfo { + // Which waves the op is contained within. + llvm::BitVector membership; + // Which waves transitively depend on this operation. + llvm::BitVector hazards; + }; + DenseMap opInfos; + + for (auto &op : llvm::reverse(*block)) { + // Skip constants; they just add noise (and since they are heavily CSE'd + // they have lots of users to test). + if (op.hasTrait()) { + LLVM_DEBUG(llvm::dbgs() << "(ignoring constant)\n"); + continue; + } + + // Initialize op info for this op - whether streamable or not. We track + // transitive hazards on each op. Note that thanks to the ordering of ops + // in SSA form (_reversed here!_) we know that once we visit this op no + // wave created after it can ever depend on it if it doesn't here. This + // lets us keep the bitvectors small. + auto &opInfo = opInfos[&op]; + opInfo.hazards.reserve(builders.size() + 1); + opInfo.hazards.resize(builders.size(), /*t=*/false); + + LLVM_DEBUG({ + llvm::dbgs() << "====\nPartitioning op:\n"; + op.dump(); + }); + + // Set bits for each wave this op may be able to be placed into. + // We prune the set based on whether the users are part of a transitive + // dependency chain down the use-def chain to a wave. + llvm::BitVector consumers(builders.size(), /*t=*/false); + for (auto user : op.getUsers()) { + auto &userInfo = opInfos[user]; + LLVM_DEBUG({ + llvm::dbgs() << "Testing user:\n"; + user->dump(); + for (auto membershipOrdinal : userInfo.membership.set_bits()) { + llvm::dbgs() << " member of wave " << membershipOrdinal << "\n"; + } + int lastHazardOrdinal = userInfo.hazards.find_last(); + if (lastHazardOrdinal != -1) { + llvm::dbgs() << " hazard w/ waves 0-" << lastHazardOrdinal << "\n"; + } + }); + consumers |= userInfo.membership; + opInfo.hazards |= userInfo.membership; + opInfo.hazards |= userInfo.hazards; + } + llvm::BitVector candidates(builders.size(), /*t=*/true); + candidates ^= opInfo.hazards; + + // If this op is not streamable then bail here; we've still setup the hazard + // map for following iteration. + auto streamableOp = dyn_cast(op); + if (!streamableOp || streamableOp.isMetadata()) { + LLVM_DEBUG(llvm::dbgs() << "Not streamable/is subview (skip)\n"); + continue; + } + + opInfo.membership.reserve(builders.size() + 1); + opInfo.membership.resize(builders.size(), /*t=*/false); + + // No consumers - if there's any candidate then we'll go into that. + int firstCandidateOrdinal = favor == IREE::Stream::Favor::MinPeakMemory + ? candidates.find_first() + : candidates.find_last(); + if (firstCandidateOrdinal != -1) { + LLVM_DEBUG(llvm::dbgs() << "Moving to last candidate wave " + << firstCandidateOrdinal << " (continue)\n"); + builders[firstCandidateOrdinal]->ops.insert(&op); + opInfo.membership.set(firstCandidateOrdinal); + opInfo.hazards.set(0, firstCandidateOrdinal); + opInfo.hazards.reset(firstCandidateOrdinal); + continue; + } + + // Mark the op as having hazards against all other waves. + opInfo.hazards.set(0, builders.size()); + + // Create a new wave just for this op. + opInfo.membership.resize(opInfo.membership.size() + 1, /*t=*/true); + auto builder = std::make_unique(); + builder->ordinal = builders.size(); + builder->ops.insert(&op); + LLVM_DEBUG(llvm::dbgs() << "Created wave " << builder->ordinal << "\n"); + builders.push_back(std::move(builder)); + } + + // Emit waves in forward order (as they are topologically sorted in + // reverse order from our bottom-up walk). + for (auto &builder : llvm::reverse(builders)) { + Partition wave; + + SetVector consumedValues; + SetVector producedValues; + SetVector escapingValues; + for (auto *op : llvm::reverse(builder->ops)) { + for (auto operand : op->getOperands()) { + consumedValues.insert(operand); + } + for (auto result : op->getResults()) { + producedValues.insert(result); + // TODO(benvanik): optimize this - creates n^2/nlogn behavior. + for (auto user : result.getUsers()) { + if (!builder->ops.contains(user)) { + escapingValues.insert(result); + } + } + } + } + consumedValues.set_subtract(producedValues); + wave.ins = consumedValues; + wave.outs = escapingValues; + + wave.ops = std::move(builder->ops); + waveSet.partitions.push_back(std::move(wave)); + } + + LLVM_DEBUG(waveSet.dump(block->getParentOp())); + + return waveSet; +} + +} // namespace Stream +} // namespace IREE +} // namespace iree_compiler +} // namespace mlir diff --git a/iree/compiler/Dialect/Stream/Transforms/BUILD b/iree/compiler/Dialect/Stream/Transforms/BUILD index c3d8fe1be6c7..754a920116c8 100644 --- a/iree/compiler/Dialect/Stream/Transforms/BUILD +++ b/iree/compiler/Dialect/Stream/Transforms/BUILD @@ -23,6 +23,8 @@ cc_library( "PassDetail.h", "Passes.cpp", "RefineUsage.cpp", + "ScheduleConcurrency.cpp", + "ScheduleExecution.cpp", "VerifyLowerings.cpp", ], hdrs = [ diff --git a/iree/compiler/Dialect/Stream/Transforms/CMakeLists.txt b/iree/compiler/Dialect/Stream/Transforms/CMakeLists.txt index 91f94702f981..32ab4a8ae833 100644 --- a/iree/compiler/Dialect/Stream/Transforms/CMakeLists.txt +++ b/iree/compiler/Dialect/Stream/Transforms/CMakeLists.txt @@ -25,6 +25,8 @@ iree_cc_library( "PassDetail.h" "Passes.cpp" "RefineUsage.cpp" + "ScheduleConcurrency.cpp" + "ScheduleExecution.cpp" "VerifyLowerings.cpp" DEPS ::PassesIncGen diff --git a/iree/compiler/Dialect/Stream/Transforms/Passes.cpp b/iree/compiler/Dialect/Stream/Transforms/Passes.cpp index 49da4e68b540..aa3ac013eead 100644 --- a/iree/compiler/Dialect/Stream/Transforms/Passes.cpp +++ b/iree/compiler/Dialect/Stream/Transforms/Passes.cpp @@ -131,6 +131,22 @@ void buildStreamAsyncPassPipeline(OpPassManager &passManager, // move across devices. We do it before scheduling waves as lifetime doesn't // change and it makes the IR cleaner. passManager.addPass(IREE::Stream::createRefineUsagePass()); + + //---------------------------------------------------------------------------- + // Stream formation and scheduling + //---------------------------------------------------------------------------- + + // Combine async work into execution regions. + passManager.addNestedPass( + IREE::Stream::createScheduleExecutionPass()); + passManager.addNestedPass( + IREE::Stream::createScheduleExecutionPass()); + + // Group concurrently executable work into waves. + passManager.addNestedPass( + IREE::Stream::createScheduleConcurrencyPass()); + passManager.addNestedPass( + IREE::Stream::createScheduleConcurrencyPass()); } //===----------------------------------------------------------------------===// diff --git a/iree/compiler/Dialect/Stream/Transforms/Passes.h b/iree/compiler/Dialect/Stream/Transforms/Passes.h index 8f6981899385..7d3ba8d97e0c 100644 --- a/iree/compiler/Dialect/Stream/Transforms/Passes.h +++ b/iree/compiler/Dialect/Stream/Transforms/Passes.h @@ -88,6 +88,13 @@ std::unique_ptr> createMaterializeCopyOnWritePass(); std::unique_ptr> createElideAsyncCopiesPass(); std::unique_ptr> createRefineUsagePass(); +//===----------------------------------------------------------------------===// +// Stream formation and scheduling +//===----------------------------------------------------------------------===// + +std::unique_ptr> createScheduleExecutionPass(); +std::unique_ptr> createScheduleConcurrencyPass(); + //===----------------------------------------------------------------------===// // Diagnostics //===----------------------------------------------------------------------===// diff --git a/iree/compiler/Dialect/Stream/Transforms/Passes.td b/iree/compiler/Dialect/Stream/Transforms/Passes.td index a9c9701100ec..f60c7e40577d 100644 --- a/iree/compiler/Dialect/Stream/Transforms/Passes.td +++ b/iree/compiler/Dialect/Stream/Transforms/Passes.td @@ -69,6 +69,26 @@ def RefineUsage : }]; } +//===----------------------------------------------------------------------===// +// Stream formation and scheduling +//===----------------------------------------------------------------------===// + +def ScheduleExecution : + Pass<"iree-stream-schedule-execution", ""> { + let summary = "Identifies and groups asynchronous operations into executable regions within function-like regions."; + let constructor = [{ + mlir::iree_compiler::IREE::Stream::createScheduleExecutionPass() + }]; +} + +def ScheduleConcurrency : + Pass<"iree-stream-schedule-concurrency", ""> { + let summary = "Identifies and groups asynchronous operations within executable regions that can run concurrently and groups them into streams."; + let constructor = [{ + mlir::iree_compiler::IREE::Stream::createScheduleConcurrencyPass() + }]; +} + //===----------------------------------------------------------------------===// // Diagnostics //===----------------------------------------------------------------------===// diff --git a/iree/compiler/Dialect/Stream/Transforms/ScheduleConcurrency.cpp b/iree/compiler/Dialect/Stream/Transforms/ScheduleConcurrency.cpp new file mode 100644 index 000000000000..577ed25193cd --- /dev/null +++ b/iree/compiler/Dialect/Stream/Transforms/ScheduleConcurrency.cpp @@ -0,0 +1,281 @@ +// Copyright 2021 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include "iree/compiler/Dialect/Stream/Analysis/Partitioning.h" +#include "iree/compiler/Dialect/Stream/IR/StreamDialect.h" +#include "iree/compiler/Dialect/Stream/IR/StreamOps.h" +#include "iree/compiler/Dialect/Stream/IR/StreamTypes.h" +#include "iree/compiler/Dialect/Stream/Transforms/PassDetail.h" +#include "iree/compiler/Dialect/Stream/Transforms/Passes.h" +#include "iree/compiler/Dialect/Util/IR/UtilDialect.h" +#include "iree/compiler/Dialect/Util/IR/UtilOps.h" +#include "iree/compiler/Dialect/Util/IR/UtilTypes.h" +#include "llvm/ADT/BitVector.h" +#include "llvm/Support/Debug.h" +#include "mlir/Dialect/StandardOps/IR/Ops.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/BlockAndValueMapping.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/Matchers.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Pass/Pass.h" + +#define DEBUG_TYPE "iree-stream-schedule-concurrency" + +namespace mlir { +namespace iree_compiler { +namespace IREE { +namespace Stream { +namespace { + +// TODO(benvanik): deduplicate this with ScheduleExecution - almost all of this +// is identical. + +// Incremental builder for a partitioned region of executable work. +// Must be constructed in a topological order of all partitions. +struct WavePartitionBuilder { + explicit WavePartitionBuilder(Block *parentBlock, size_t ordinal, + Partition *partition, + BlockAndValueMapping &parentMapping, + MLIRContext *context) + : ordinal(ordinal), partition(partition), builder(context) { + // Fuse the location of all ops we'll be putting in the partition. + SmallVector locs; + for (auto *op : partition->ops) { + locs.push_back(op->getLoc()); + } + auto fusedLoc = FusedLoc::get(context, locs); + + // Find the insertion point in the parent block. + // This is at the last op defining an input as all inputs must be available. + Operation *insertionPt = nullptr; + for (auto in : partition->ins) { + auto *definingOp = in.getDefiningOp(); + if (!definingOp) continue; + if (definingOp->getBlock() != parentBlock) continue; + if (!insertionPt) { + insertionPt = definingOp; // first defining op + } else if (insertionPt->isBeforeInBlock(definingOp)) { + insertionPt = definingOp; // moving insertion point down + } + } + OpBuilder parentBuilder(context); + if (insertionPt) { + parentBuilder.setInsertionPointAfter(insertionPt); + } else { + parentBuilder.setInsertionPointToStart(parentBlock); + } + + // Gather operands and result types from the declared partition I/O. + // These are values from the original block. Note that because we are + // constructing in order we know that any results of prior partitions are + // in the |parentMapping|. + SmallVector resultTypes; + SmallVector resultSizes; + resultTypes.reserve(partition->outs.size()); + resultSizes.reserve(partition->outs.size()); + for (auto out : partition->outs) { + resultTypes.push_back(out.getType()); + auto resultSize = IREE::Util::SizeAwareTypeInterface::queryValueSize( + fusedLoc, out, parentBuilder); + if (resultSize) resultSizes.push_back(resultSize); + } + SmallVector operands; + SmallVector operandTypes; + SmallVector operandSizes; + operands.reserve(partition->ins.size()); + operandTypes.reserve(partition->ins.size()); + operandSizes.reserve(partition->ins.size()); + for (auto in : partition->ins) { + if (!in.getType().isa()) continue; + operands.push_back(in); + operandTypes.push_back(in.getType()); + auto operandSize = IREE::Util::SizeAwareTypeInterface::queryValueSize( + fusedLoc, in, parentBuilder); + if (operandSize) operandSizes.push_back(operandSize); + } + + // TODO(benvanik): tie operands, or leave to canonicalization. + SmallVector tiedOperands; + concurrentOp = parentBuilder.create( + fusedLoc, resultTypes, resultSizes, operands, operandSizes, + tiedOperands); + + // Add entry block and arguments. + auto &entryBlock = concurrentOp.body().emplaceBlock(); + for (auto args : + llvm::zip(operands, entryBlock.addArguments(operandTypes))) { + mapping.map(std::get<0>(args), std::get<1>(args)); + } + builder = OpBuilder::atBlockBegin(&entryBlock); + + // Remap results for escaping outputs. + for (auto results : llvm::zip(partition->outs, concurrentOp.results())) { + parentMapping.map(std::get<0>(results), std::get<1>(results)); + } + } + + // Visits a block operation and clones it into the partition, if desired. + // + // Slightly suboptimal to be calling this on each op for each partition, + // however we only walk the block once and constructing a multimap would be + // way worse. + // + // Returns true if the operation was cloned into the partition. + bool visit(Operation *op) { + if (!partition->ops.contains(op)) return false; + + // Clone the op into the partition and remap it. + auto *clonedOp = builder.clone(*op, mapping); + (void)clonedOp; + LLVM_DEBUG({ + llvm::dbgs() << "Cloned op into partition " << ordinal << ": "; + clonedOp->dump(); + }); + + return true; + } + + void finish() { + // Gather results mapped into the SSA values we've cloned. + SmallVector results; + SmallVector resultSizes; + results.reserve(partition->outs.size()); + resultSizes.reserve(partition->outs.size()); + for (auto oldResult : partition->outs) { + auto newResult = mapping.lookup(oldResult); + results.push_back(newResult); + auto resultSize = IREE::Util::SizeAwareTypeInterface::queryValueSize( + concurrentOp.getLoc(), newResult, builder); + if (resultSize) resultSizes.push_back(resultSize); + } + builder.create(concurrentOp.getLoc(), results, + resultSizes); + } + + size_t ordinal = -1; + Partition *partition = nullptr; + IREE::Stream::AsyncConcurrentOp concurrentOp; + OpBuilder builder; + BlockAndValueMapping mapping; +}; + +class ScheduleConcurrencyPass + : public ScheduleConcurrencyBase { + public: + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); + registry.insert(); + } + + void runOnOperation() override { + auto parentOp = dyn_cast(getOperation()); + if (!parentOp || !parentOp.getCallableRegion() || + parentOp.getCallableRegion()->empty()) { + return; + } + for (auto executeOp : + parentOp.getCallableRegion()->getOps()) { + if (failed(runOnRegion(executeOp))) return signalPassFailure(); + } + } + + LogicalResult runOnRegion(IREE::Stream::AsyncExecuteOp parentOp) { + if (parentOp.body().empty()) { + return success(); + } + auto *block = &parentOp.body().front(); + + // Lookup the optional config used to control partitioning. + auto configAttr = IREE::Stream::PartitioningConfigAttr::lookup(parentOp); + + // Compute a set of partitions covering all of the streamable ops in the + // execution region. + auto waveSet = partitionRegionConcurrency(configAttr, block); + if (waveSet.empty()) return success(); + if (failed(waveSet.verify(parentOp.getLoc()))) return failure(); + + // Create partition builders for each partition. + // We'll clone ops into each and insert them into the block at the + // appropriate position (first use... probably). + BlockAndValueMapping mapping; + SmallVector partitionBuilders; + partitionBuilders.reserve(waveSet.size()); + for (auto partition : llvm::enumerate(waveSet.partitions)) { + if (partition.value().ops.size() == 1) continue; + partitionBuilders.push_back(WavePartitionBuilder(block, partition.index(), + &partition.value(), + mapping, &getContext())); + } + + // Walk over each op in the original block and find those that need to be + // partitioned. Each partition builder may clone the op into itself. The + // op will always be left in the original block and we'll rely on DCE to + // remove the ones no longer required. This is not a good approach as it + // creates a lot of new IR (up to O(op*partitions)). + SetVector deadOps; + for (auto &op : *block) { + if (op.hasTrait()) continue; + bool handled = false; + for (auto &partitionBuilder : partitionBuilders) { + handled = partitionBuilder.visit(&op) || handled; + } + if (handled) { + deadOps.insert(&op); + } + } + + // Apply remapping for values captured/escaping partitions. + // We must do this per block as we'll be updating dominated block values. + for (auto &partitionBuilder : partitionBuilders) { + for (auto resultPair : + llvm::zip(partitionBuilder.partition->outs, + partitionBuilder.concurrentOp.results())) { + auto oldResult = std::get<0>(resultPair); + auto newResult = std::get<1>(resultPair); + oldResult.replaceAllUsesWith(newResult); + deadOps.insert(oldResult.getDefiningOp()); + } + partitionBuilder.finish(); + + // Extremely shady reordering of ops we know (should) be safe to move + // after the partition - otherwise, we shouldn't have moved the source + // ops into the partition. + auto concurrentOp = partitionBuilder.concurrentOp; + for (auto user : concurrentOp->getUsers()) { + if (user->getBlock() == concurrentOp->getBlock() && + user->isBeforeInBlock(partitionBuilder.concurrentOp)) { + LLVM_DEBUG({ + llvm::dbgs() << "Shady move of op to after partition: "; + user->dump(); + }); + user->moveAfter(concurrentOp); + } + } + } + for (auto *deadOp : llvm::reverse(deadOps)) { + deadOp->erase(); + } + + LLVM_DEBUG({ + llvm::dbgs() << "\nWaves constructed:\n"; + block->dump(); + }); + return success(); + } +}; + +} // namespace + +std::unique_ptr> createScheduleConcurrencyPass() { + return std::make_unique(); +} + +} // namespace Stream +} // namespace IREE +} // namespace iree_compiler +} // namespace mlir diff --git a/iree/compiler/Dialect/Stream/Transforms/ScheduleExecution.cpp b/iree/compiler/Dialect/Stream/Transforms/ScheduleExecution.cpp new file mode 100644 index 000000000000..8c5c4e5e74a0 --- /dev/null +++ b/iree/compiler/Dialect/Stream/Transforms/ScheduleExecution.cpp @@ -0,0 +1,347 @@ +// Copyright 2021 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include "iree/compiler/Dialect/Stream/Analysis/Partitioning.h" +#include "iree/compiler/Dialect/Stream/IR/StreamDialect.h" +#include "iree/compiler/Dialect/Stream/IR/StreamOps.h" +#include "iree/compiler/Dialect/Stream/IR/StreamTypes.h" +#include "iree/compiler/Dialect/Stream/Transforms/PassDetail.h" +#include "iree/compiler/Dialect/Stream/Transforms/Passes.h" +#include "iree/compiler/Dialect/Util/IR/UtilDialect.h" +#include "iree/compiler/Dialect/Util/IR/UtilOps.h" +#include "iree/compiler/Dialect/Util/IR/UtilTypes.h" +#include "llvm/ADT/EquivalenceClasses.h" +#include "llvm/Support/Debug.h" +#include "mlir/Dialect/StandardOps/IR/Ops.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/BlockAndValueMapping.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/Dominance.h" +#include "mlir/IR/Matchers.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassRegistry.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +#define DEBUG_TYPE "iree-stream-schedule-execution" + +namespace mlir { +namespace iree_compiler { +namespace IREE { +namespace Stream { +namespace { + +// Incremental builder for a partitioned region of executable work. +// Must be constructed in a topological order of all partitions. +struct ExecutePartitionBuilder { + explicit ExecutePartitionBuilder(Block *parentBlock, size_t ordinal, + Partition *partition, + BlockAndValueMapping &parentMapping, + MLIRContext *context) + : ordinal(ordinal), partition(partition), builder(context) { + // Fuse the location of all ops we'll be putting in the partition. + SmallVector locs; + for (auto *op : partition->ops) { + locs.push_back(op->getLoc()); + } + auto fusedLoc = FusedLoc::get(context, locs); + + // Find the insertion point in the parent block. + // This is at the last op defining an input as all inputs must be available. + Operation *insertionPt = nullptr; + for (auto in : partition->ins) { + auto *definingOp = in.getDefiningOp(); + if (!definingOp) continue; + if (definingOp->getBlock() != parentBlock) continue; + if (!insertionPt) { + insertionPt = definingOp; // first defining op + } else if (insertionPt->isBeforeInBlock(definingOp)) { + insertionPt = definingOp; // moving insertion point down + } + } + OpBuilder parentBuilder(context); + if (insertionPt) { + parentBuilder.setInsertionPointAfter(insertionPt); + } else { + parentBuilder.setInsertionPointToStart(parentBlock); + } + + // Gather operands and result types from the declared partition I/O. + // These are values from the original block. Note that because we are + // constructing in order we know that any results of prior partitions are + // in the |parentMapping|. + SmallVector resultTypes; + SmallVector resultSizes; + resultTypes.reserve(partition->outs.size()); + resultSizes.reserve(partition->outs.size()); + for (auto out : partition->outs) { + resultTypes.push_back(out.getType()); + auto resultSize = IREE::Util::SizeAwareTypeInterface::queryValueSize( + fusedLoc, out, parentBuilder); + if (resultSize) resultSizes.push_back(resultSize); + } + SmallVector operands; + SmallVector operandTypes; + SmallVector operandSizes; + operands.reserve(partition->ins.size()); + operandTypes.reserve(partition->ins.size()); + operandSizes.reserve(partition->ins.size()); + for (auto in : partition->ins) { + if (!in.getType().isa()) continue; + operands.push_back(in); + operandTypes.push_back(in.getType()); + auto operandSize = IREE::Util::SizeAwareTypeInterface::queryValueSize( + fusedLoc, in, parentBuilder); + if (operandSize) operandSizes.push_back(operandSize); + } + + // TODO(benvanik): tie operands, or leave to canonicalization. + SmallVector tiedOperands; + executeOp = parentBuilder.create( + fusedLoc, resultTypes, resultSizes, /*awaitTimepoint=*/Value{}, + operands, operandSizes, tiedOperands); + + // Add entry block and arguments. + auto &entryBlock = executeOp.body().emplaceBlock(); + for (auto args : + llvm::zip(operands, entryBlock.addArguments(operandTypes))) { + mapping.map(std::get<0>(args), std::get<1>(args)); + } + builder = OpBuilder::atBlockBegin(&entryBlock); + + // Remap results for escaping outputs. + for (auto results : llvm::zip(partition->outs, executeOp.results())) { + parentMapping.map(std::get<0>(results), std::get<1>(results)); + } + } + + // Visits a block operation and clones it into the partition, if desired. + // + // Slightly suboptimal to be calling this on each op for each partition, + // however we only walk the block once and constructing a multimap would be + // way worse. + // + // Returns true if the operation was cloned into the partition. + bool visit(Operation *op) { + if (!partition->ops.contains(op)) return false; + + // Clone the op into the partition and remap it. + auto *clonedOp = builder.clone(*op, mapping); + (void)clonedOp; + LLVM_DEBUG({ + llvm::dbgs() << "Cloned op into partition " << ordinal << ": "; + clonedOp->dump(); + }); + + return true; + } + + IREE::Stream::AsyncExecuteOp finish() { + // Gather results mapped into the SSA values we've cloned. + SmallVector results; + SmallVector resultSizes; + results.reserve(partition->outs.size()); + resultSizes.reserve(partition->outs.size()); + for (auto oldResult : partition->outs) { + auto newResult = mapping.lookup(oldResult); + results.push_back(newResult); + auto resultSize = IREE::Util::SizeAwareTypeInterface::queryValueSize( + executeOp.getLoc(), newResult, builder); + if (resultSize) resultSizes.push_back(resultSize); + } + builder.create(executeOp.getLoc(), results, + resultSizes); + return executeOp; + } + + size_t ordinal = -1; + Partition *partition = nullptr; + IREE::Stream::AsyncExecuteOp executeOp; + OpBuilder builder; + BlockAndValueMapping mapping; +}; + +// Sorts blocks in dominance order such that the entry block is first and +// all of the following blocks are dominated only by blocks that have come +// before them in the list. +static SmallVector sortBlocksInDominanceOrder(Region ®ion) { + if (region.getBlocks().size() == 1) { + // Dominance info cannot be computed for regions with one block. + return {®ion.getBlocks().front()}; + } + + DominanceInfo dominanceInfo(region.getParentOp()); + llvm::SmallSetVector unmarkedBlocks; + for (auto &block : region.getBlocks()) { + unmarkedBlocks.insert(&block); + } + llvm::SmallSetVector markedBlocks; + std::function visit = [&](Block *block) { + if (markedBlocks.count(block) > 0) return; + for (auto *childBlock : dominanceInfo.getNode(block)->children()) { + visit(childBlock->getBlock()); + } + markedBlocks.insert(block); + }; + while (!unmarkedBlocks.empty()) { + visit(unmarkedBlocks.pop_back_val()); + } + auto orderedBlocks = markedBlocks.takeVector(); + std::reverse(orderedBlocks.begin(), orderedBlocks.end()); + return orderedBlocks; +} + +class ScheduleExecutionPass + : public ScheduleExecutionBase { + public: + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); + registry.insert(); + } + + void runOnOperation() override { + auto *context = &getContext(); + auto parentOp = dyn_cast(getOperation()); + if (!parentOp || !parentOp.getCallableRegion() || + parentOp.getCallableRegion()->empty()) { + return; + } + + // Lookup the optional config used to control partitioning. + auto configAttr = IREE::Stream::PartitioningConfigAttr::lookup(parentOp); + + // Partition each block on its own. We could try to partition with the CFG + // however that's much more complex - it's easier to handle partitioning + // structured control flow (scf) ops. Note that we do this in dominance + // order so that we are sure if we replace values that dominate other blocks + // they see the correct values. + auto ®ion = *parentOp.getCallableRegion(); + for (auto *block : sortBlocksInDominanceOrder(region)) { + // Compute a set of partitions covering all of the streamable ops in the + // block. + auto partitionSet = partitionStreamableOps(configAttr, block); + if (partitionSet.empty()) continue; + if (failed(partitionSet.verify(parentOp.getLoc()))) { + return signalPassFailure(); + } + + // Create partition builders for each partition. + // We'll clone ops into each and insert them into the block at the + // appropriate position (first use... probably). + BlockAndValueMapping mapping; + SmallVector partitionBuilders; + partitionBuilders.reserve(partitionSet.size()); + for (auto partition : llvm::enumerate(partitionSet.partitions)) { + partitionBuilders.push_back(ExecutePartitionBuilder( + block, partition.index(), &partition.value(), mapping, context)); + } + + // Walk over each op in the original block and find those that need to be + // partitioned. Each partition builder may clone the op into itself. The + // op will always be left in the original block and we'll rely on DCE to + // remove the ones no longer required. This is not a good approach as it + // creates a lot of new IR (up to O(op*partitions)). + SetVector deadOps; + for (auto &op : *block) { + if (op.hasTrait()) continue; + for (auto &partitionBuilder : partitionBuilders) { + partitionBuilder.visit(&op); + } + if (isa(op)) { + deadOps.insert(&op); + } + } + + // Apply remapping for values captured/escaping partitions. + // We must do this per block as we'll be updating dominated block values. + for (auto &partitionBuilder : partitionBuilders) { + // Finish construction and insert the yield. + auto executeOp = partitionBuilder.finish(); + + OpBuilder builder(executeOp); + builder.setInsertionPointAfter(executeOp); + for (auto it : + llvm::zip(partitionBuilder.partition->outs, executeOp.results(), + executeOp.result_sizes())) { + auto oldResult = std::get<0>(it); + auto newResult = std::get<1>(it); + auto newResultSize = std::get<2>(it); + + // Insert one await per result. We could batch them all but that would + // prematurely tie their lifetimes together. By having unique awaits + // we allow propagation to move the waits further to where the values + // are used (including right into other execution regions). + auto awaitOp = builder.create( + executeOp.getLoc(), newResult, newResultSize, + executeOp.result_timepoint()); + if (executeOp.affinity().hasValue()) { + awaitOp.affinityAttr(executeOp.affinityAttr()); + } + + oldResult.replaceAllUsesWith(awaitOp.results().front()); + deadOps.insert(oldResult.getDefiningOp()); + } + + // Extremely shady reordering of ops we know (should) be safe to move + // after the partition - otherwise, we shouldn't have moved the source + // ops into the partition. + SetVector worklist; + for (auto user : executeOp->getUsers()) { + worklist.insert(user); + } + while (!worklist.empty()) { + auto *user = worklist.pop_back_val(); + if (user->getBlock() == executeOp->getBlock() && + user->isBeforeInBlock(executeOp)) { + LLVM_DEBUG({ + llvm::dbgs() << "Shady move of op to after partition: "; + user->dump(); + }); + user->moveAfter(builder.getInsertionBlock(), + builder.getInsertionPoint()); + } + for (auto subUser : user->getUsers()) { + worklist.insert(subUser); + } + } + } + for (auto *deadOp : llvm::reverse(deadOps)) { + deadOp->erase(); + } + + LLVM_DEBUG({ + llvm::dbgs() << "\nPartitions constructed:\n"; + block->dump(); + }); + } + + // Cleanup the dead ops. + // TODO(benvanik): less work here - maybe no patterns to just force folding? + OwningRewritePatternList patterns(context); + for (auto *dialect : context->getLoadedDialects()) { + dialect->getCanonicalizationPatterns(patterns); + } + for (auto *op : context->getRegisteredOperations()) { + op->getCanonicalizationPatterns(patterns, context); + } + FrozenRewritePatternSet frozenPatterns(std::move(patterns)); + if (failed(applyPatternsAndFoldGreedily(getOperation(), frozenPatterns))) { + return signalPassFailure(); + } + } +}; + +} // namespace + +std::unique_ptr> createScheduleExecutionPass() { + return std::make_unique(); +} + +} // namespace Stream +} // namespace IREE +} // namespace iree_compiler +} // namespace mlir diff --git a/iree/compiler/Dialect/Stream/Transforms/test/BUILD b/iree/compiler/Dialect/Stream/Transforms/test/BUILD index 16765b22e3d7..c9a1c8172e15 100644 --- a/iree/compiler/Dialect/Stream/Transforms/test/BUILD +++ b/iree/compiler/Dialect/Stream/Transforms/test/BUILD @@ -23,6 +23,8 @@ iree_lit_test_suite( "materialize_copy_on_write.mlir", "outline_constants.mlir", "refine_usage.mlir", + "schedule_concurrency.mlir", + "schedule_execution.mlir", ], include = ["*.mlir"], ), diff --git a/iree/compiler/Dialect/Stream/Transforms/test/CMakeLists.txt b/iree/compiler/Dialect/Stream/Transforms/test/CMakeLists.txt index 3317a4da47fa..f83d356f790c 100644 --- a/iree/compiler/Dialect/Stream/Transforms/test/CMakeLists.txt +++ b/iree/compiler/Dialect/Stream/Transforms/test/CMakeLists.txt @@ -20,6 +20,8 @@ iree_lit_test_suite( "materialize_copy_on_write.mlir" "outline_constants.mlir" "refine_usage.mlir" + "schedule_concurrency.mlir" + "schedule_execution.mlir" DATA iree::tools::IreeFileCheck iree::tools::iree-opt diff --git a/iree/compiler/Dialect/Stream/Transforms/test/schedule_concurrency.mlir b/iree/compiler/Dialect/Stream/Transforms/test/schedule_concurrency.mlir new file mode 100644 index 000000000000..2cd287ae8c13 --- /dev/null +++ b/iree/compiler/Dialect/Stream/Transforms/test/schedule_concurrency.mlir @@ -0,0 +1,47 @@ +// RUN: iree-opt -split-input-file -pass-pipeline="builtin.func(iree-stream-schedule-concurrency)" %s | IreeFileCheck %s + +// CHECK-LABEL: @partitioning +// CHECK-SAME: (%[[ARG0:.+]]: !stream.resource, %[[ARG1:.+]]: !stream.resource) +func @partitioning(%arg0: !stream.resource, %arg1: !stream.resource) -> !stream.resource { + %c1 = arith.constant 1 : index + %c20 = arith.constant 20 : index + %c80 = arith.constant 80 : index + %c1280 = arith.constant 1280 : index + %cst = arith.constant 0x7F800000 : f32 + // CHECK: stream.async.execute + %results, %result_timepoint = stream.async.execute + // CHECK-SAME: with(%[[ARG1]] as %[[ARG1_CAPTURE:.+]]: !stream.resource{%c80}, + // CHECK-SAME: %[[ARG0]] as %[[ARG0_CAPTURE:.+]]: !stream.resource{%c20}) + with(%arg1 as %arg2: !stream.resource{%c80}, + %arg0 as %arg3: !stream.resource{%c20}) + -> !stream.resource{%c20} { + + // CHECK: %[[CON0:.+]]:2 = stream.async.concurrent with() + // CHECK-SAME: -> (!stream.resource{%c1280}, !stream.resource{%c20}) { + // CHECK-NEXT: %[[SPLAT0:.+]] = stream.async.splat %cst : f32 -> !stream.resource{%c1280} + // CHECK-NEXT: %[[SPLAT1:.+]] = stream.async.splat %cst : f32 -> !stream.resource{%c20} + // CHECK-NEXT: stream.yield %[[SPLAT0]], %[[SPLAT1]] : !stream.resource{%c1280}, !stream.resource{%c20} + + // CHECK: %[[CON1:.+]]:2 = stream.async.concurrent + // CHECK-SAME: with(%[[CON0]]#0 as %[[CON0_0_CAPTURE:.+]]: !stream.resource{%c1280}, + // CHECK-SAME: %[[ARG1_CAPTURE]] as %[[ARG1_CON1_CAPTURE:.+]]: !stream.resource{%c80}, + // CHECK-SAME: %[[ARG0_CAPTURE]] as %[[ARG0_CON1_CAPTURE:.+]]: !stream.resource{%c20}, + // CHECK-SAME: %[[CON0]]#1 as %[[CON0_1_CAPTURE:.+]]: !stream.resource{%c20}) + // CHECK-SAME: -> (!stream.resource{%c1280}, !stream.resource{%c20}) { + // CHECK-NEXT: %[[DISPATCH0:.+]] = stream.async.dispatch @ex::@dispatch_0[%c1, %c1, %c1](%[[CON0_0_CAPTURE]], %[[ARG1_CON1_CAPTURE]]) + // CHECK-NEXT: %[[DISPATCH1:.+]] = stream.async.dispatch @ex::@dispatch_1[%c1, %c1, %c1](%[[ARG0_CON1_CAPTURE]], %[[CON0_1_CAPTURE]]) + // CHECK-NEXT: stream.yield %[[DISPATCH0]], %[[DISPATCH1]] + + // CHECK: %[[DISPATCH2:.+]] = stream.async.dispatch @ex::@dispatch_2[%c1, %c1, %c1](%[[CON1]]#0, %[[CON1]]#1) + // CHECK-NEXT: stream.yield %[[DISPATCH2]] + + %1 = stream.async.splat %cst : f32 -> !stream.resource{%c1280} + %2 = stream.async.dispatch @ex::@dispatch_0[%c1, %c1, %c1](%1, %arg2) : (!stream.resource{%c1280}, !stream.resource{%c80}) -> %1{%c1280} + %3 = stream.async.splat %cst : f32 -> !stream.resource{%c20} + %4 = stream.async.dispatch @ex::@dispatch_1[%c1, %c1, %c1](%arg3, %3) : (!stream.resource{%c20}, !stream.resource{%c20}) -> %3{%c20} + %5 = stream.async.dispatch @ex::@dispatch_2[%c1, %c1, %c1](%2, %4) : (!stream.resource{%c1280}, !stream.resource{%c20}) -> !stream.resource{%c20} + stream.yield %5 : !stream.resource{%c20} + } => !stream.timepoint + %0 = stream.timepoint.await %result_timepoint => %results : !stream.resource{%c20} + return %0 : !stream.resource +} diff --git a/iree/compiler/Dialect/Stream/Transforms/test/schedule_execution.mlir b/iree/compiler/Dialect/Stream/Transforms/test/schedule_execution.mlir new file mode 100644 index 000000000000..4f1b70b26711 --- /dev/null +++ b/iree/compiler/Dialect/Stream/Transforms/test/schedule_execution.mlir @@ -0,0 +1,104 @@ +// RUN: iree-opt -split-input-file -pass-pipeline="builtin.func(iree-stream-schedule-execution)" %s | IreeFileCheck %s + +// Tests basic partitioning of multiple ops. + +// CHECK-LABEL: @partitioning +// CHECK-SAME: (%[[ARG0:.+]]: !stream.resource, %[[ARG1:.+]]: !stream.resource) +func @partitioning(%arg0: !stream.resource, %arg1: !stream.resource) -> !stream.resource { + %c1 = arith.constant 1 : index + %c20 = arith.constant 20 : index + %c80 = arith.constant 80 : index + %c1280 = arith.constant 1280 : index + %cst = arith.constant 0x7F800000 : f32 + // CHECK: %[[RESULT:.+]], %[[TIMEPOINT:.+]] = stream.async.execute + // CHECK-SAME: with(%[[ARG1]] as %[[ARG1_CAPTURE:.+]]: !stream.resource{%c80}, + // CHECK-SAME: %[[ARG0]] as %[[ARG0_CAPTURE:.+]]: !stream.resource{%c20}) + // CHECK-SAME: -> !stream.resource{%c20} { + // CHECK-NEXT: %[[SPLAT0:.+]] = stream.async.splat + %2 = stream.async.splat %cst : f32 -> !stream.resource{%c1280} + // CHECK-NEXT: %[[DISPATCH0:.+]] = stream.async.dispatch @ex::@dispatch_0[%c1, %c1, %c1](%[[SPLAT0]], %[[ARG1_CAPTURE]]) : (!stream.resource{%c1280}, !stream.resource{%c80}) -> %[[SPLAT0]]{%c1280} + %3 = stream.async.dispatch @ex::@dispatch_0[%c1, %c1, %c1](%2, %arg1) : (!stream.resource{%c1280}, !stream.resource{%c80}) -> %2{%c1280} + // CHECK-NEXT: %[[SPLAT1:.+]] = stream.async.splat + %4 = stream.async.splat %cst : f32 -> !stream.resource{%c20} + // CHECK-NEXT: %[[DISPATCH1:.+]] = stream.async.dispatch @ex::@dispatch_1[%c1, %c1, %c1](%[[ARG0_CAPTURE]], %[[SPLAT1]]) : (!stream.resource{%c20}, !stream.resource{%c20}) -> %[[SPLAT1]]{%c20} + %5 = stream.async.dispatch @ex::@dispatch_1[%c1, %c1, %c1](%arg0, %4) : (!stream.resource{%c20}, !stream.resource{%c20}) -> %4{%c20} + // CHECK-NEXT: %[[DISPATCH2:.+]] = stream.async.dispatch @ex::@dispatch_2[%c1, %c1, %c1](%[[DISPATCH0]], %[[DISPATCH1]]) : (!stream.resource{%c1280}, !stream.resource{%c20}) -> !stream.resource{%c20} + %6 = stream.async.dispatch @ex::@dispatch_2[%c1, %c1, %c1](%3, %5) : (!stream.resource{%c1280}, !stream.resource{%c20}) -> !stream.resource{%c20} + // CHECK-NEXT: stream.yield %[[DISPATCH2]] : !stream.resource{%c20} + // CHECK-NEXT: } => !stream.timepoint + // CHECK-NEXT: %[[READY:.+]] = stream.timepoint.await %[[TIMEPOINT]] => %[[RESULT]] : !stream.resource{%c20} + // CHECK-NEXT: return %[[READY]] + return %6 : !stream.resource +} + +// ----- + +// Tests that ops in multiple blocks are partitioned independently and that +// timepoints are chained between the partitions. Note that the dispatches +// happen in-place on the splat and we expect the execution regions to be tied. + +// CHECK-LABEL: @partitionWithinBlocks +func @partitionWithinBlocks(%cond: i1) -> !stream.resource { + %c1 = arith.constant 1 : index + %c1280 = arith.constant 1280 : index + %cst = arith.constant 0x7F800000 : f32 + // CHECK: %[[SPLAT:.+]], %[[SPLAT_TIMEPOINT:.+]] = stream.async.execute + // CHECK: stream.async.splat + %splat = stream.async.splat %cst : f32 -> !stream.resource{%c1280} + // CHECK: cond_br + cond_br %cond, ^bb1, ^bb2 +^bb1: + // CHECK: %[[BB1_RESULT:.+]], %[[BB1_TIMEPOINT:.+]] = stream.async.execute await(%[[SPLAT_TIMEPOINT]]) => + // CHECK-SAME: with(%[[SPLAT]] as %[[BB1_SPLAT:.+]]: !stream.resource{%c1280}) + // CHECK-SAME: -> %[[SPLAT]]{%c1280} + // CHECK: stream.async.dispatch @ex::@dispatch_0[%c1, %c1, %c1](%[[BB1_SPLAT]]) : (!stream.resource{%c1280}) -> %[[BB1_SPLAT]]{%c1280} + %3 = stream.async.dispatch @ex::@dispatch_0[%c1, %c1, %c1](%splat) : (!stream.resource{%c1280}) -> %splat{%c1280} + // CHECK: %[[BB1_READY:.+]] = stream.timepoint.await %[[BB1_TIMEPOINT]] => %[[BB1_RESULT]] + // CHECK: return %[[BB1_READY]] + return %3 : !stream.resource +^bb2: + // CHECK: %[[BB2_RESULT:.+]], %[[BB2_TIMEPOINT:.+]] = stream.async.execute await(%[[SPLAT_TIMEPOINT]]) => + // CHECK-SAME: with(%[[SPLAT]] as %[[BB2_SPLAT:.+]]: !stream.resource{%c1280}) + // CHECK-SAME: -> %[[SPLAT]]{%c1280} + // CHECK: stream.async.dispatch @ex::@dispatch_1[%c1, %c1, %c1](%[[BB2_SPLAT]]) : (!stream.resource{%c1280}) -> %[[BB2_SPLAT]]{%c1280} + %4 = stream.async.dispatch @ex::@dispatch_1[%c1, %c1, %c1](%splat) : (!stream.resource{%c1280}) -> %splat{%c1280} + // CHECK: %[[BB2_READY:.+]] = stream.timepoint.await %[[BB2_TIMEPOINT]] => %[[BB2_RESULT]] + // CHECK: return %[[BB2_READY]] + return %4 : !stream.resource +} + +// ----- + +// Tests a complex device->host->device sequence gets turned into the proper +// execute->await->execute. These data-dependent operations can happen in a +// single block and break the assumption that one block == one partition. + +// CHECK-LABEL: @deviceHostDevice +func @deviceHostDevice() -> !stream.resource { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c123_i8 = arith.constant 123 : i8 + // CHECK: %[[RESULT_D2H:.+]], %[[TIMEPOINT_D2H:.+]] = stream.async.execute with() + // CHECK-SAME: -> !stream.resource{%c1} + // CHECK-NEXT: %[[SPLAT:.+]] = stream.async.splat %c123_i8 + %0 = stream.async.splat %c123_i8 : i8 -> !stream.resource{%c1} + // CHECK-NEXT: %[[TRANSFER_D2H:.+]] = stream.async.transfer %[[SPLAT]] + %1 = stream.async.transfer %0 : !stream.resource{%c1} -> !stream.resource{%c1} + // CHECK-NEXT: stream.yield %[[TRANSFER_D2H]] + // CHECK: %[[READY_D2H:.+]] = stream.timepoint.await %[[TIMEPOINT_D2H]] => %[[RESULT_D2H]] : !stream.resource{%c1} + // CHECK: %[[LOAD:.+]] = stream.async.load %[[READY_D2H]] + %2 = stream.async.load %1[%c0] : !stream.resource{%c1} -> i8 + // CHECK: %[[ADD:.+]] = arith.addi %[[LOAD]], %[[LOAD]] + %3 = arith.addi %2, %2 : i8 + // CHECK: %[[STORE:.+]] = stream.async.store %[[ADD]], %[[READY_D2H]] + %4 = stream.async.store %3, %1[%c0] : i8 -> !stream.resource{%c1} + // CHECK: %[[RESULT_H2D:.+]], %[[TIMEPOINT_H2D:.+]] = stream.async.execute + // CHECK-SAME: with(%[[STORE]] as %[[STORE_CAPTURE:.+]]: !stream.resource{%c1}) + // CHECK-SAME: -> !stream.resource{%c1} + // CHECK-NEXT: %[[TRANSFER_H2D:.+]] = stream.async.transfer %[[STORE_CAPTURE]] + %5 = stream.async.transfer %4 : !stream.resource{%c1} -> !stream.resource{%c1} + // CHECK-NEXT: stream.yield %[[TRANSFER_H2D]] + // CHECK: %[[READY_H2D:.+]] = stream.timepoint.await %[[TIMEPOINT_H2D]] => %[[RESULT_H2D]] : !stream.resource{%c1} + // CHECK: return %[[READY_H2D]] + return %5 : !stream.resource +}