diff --git a/iree/compiler/Dialect/Stream/Analysis/BUILD b/iree/compiler/Dialect/Stream/Analysis/BUILD index ad307dc93d8c..b8e7ce80a325 100644 --- a/iree/compiler/Dialect/Stream/Analysis/BUILD +++ b/iree/compiler/Dialect/Stream/Analysis/BUILD @@ -13,9 +13,12 @@ package( cc_library( name = "Analysis", srcs = [ + "Partitioning.cpp", + "Partitioning/ReferencePartitioning.cpp", "ResourceUsage.cpp", ], hdrs = [ + "Partitioning.h", "ResourceUsage.h", ], deps = [ diff --git a/iree/compiler/Dialect/Stream/Analysis/CMakeLists.txt b/iree/compiler/Dialect/Stream/Analysis/CMakeLists.txt index 2fbacefd6345..7cff2f21f27f 100644 --- a/iree/compiler/Dialect/Stream/Analysis/CMakeLists.txt +++ b/iree/compiler/Dialect/Stream/Analysis/CMakeLists.txt @@ -14,8 +14,11 @@ iree_cc_library( NAME Analysis HDRS + "Partitioning.h" "ResourceUsage.h" SRCS + "Partitioning.cpp" + "Partitioning/ReferencePartitioning.cpp" "ResourceUsage.cpp" DEPS LLVMSupport diff --git a/iree/compiler/Dialect/Stream/Analysis/Partitioning.cpp b/iree/compiler/Dialect/Stream/Analysis/Partitioning.cpp new file mode 100644 index 000000000000..104e6512c39a --- /dev/null +++ b/iree/compiler/Dialect/Stream/Analysis/Partitioning.cpp @@ -0,0 +1,183 @@ +// Copyright 2021 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include "iree/compiler/Dialect/Stream/Analysis/Partitioning.h" + +#include "llvm/ADT/STLExtras.h" +#include "llvm/Support/Debug.h" +#include "mlir/IR/AsmState.h" +#include "mlir/IR/PatternMatch.h" + +#define DEBUG_TYPE "iree-stream-partitioning" + +namespace mlir { +namespace iree_compiler { +namespace IREE { +namespace Stream { + +#ifndef NDEBUG + +void dumpPartition(Partition &partition, AsmState &state) { + llvm::dbgs() << " INS:\n "; + llvm::interleaveComma(partition.ins, llvm::dbgs(), [&](Value in) { + in.printAsOperand(llvm::dbgs(), state); + }); + llvm::dbgs() << "\n OUTS:\n "; + llvm::interleaveComma(partition.outs, llvm::dbgs(), [&](Value out) { + out.printAsOperand(llvm::dbgs(), state); + }); + llvm::dbgs() << "\n OPS:\n"; + for (auto *op : partition.ops) { + llvm::dbgs() << " "; + op->print(llvm::dbgs(), state); + llvm::dbgs() << "\n"; + } +} + +void Partition::dump(Operation *parentOp) { + AsmState state(parentOp); + dumpPartition(*this, state); +} + +void PartitionSet::dump(Operation *parentOp) { + AsmState state(parentOp); + for (auto partition : llvm::enumerate(partitions)) { + llvm::dbgs() << "PARTITION[" << partition.index() << "]:\n"; + dumpPartition(partition.value(), state); + } +} + +#else +void Partition::dump(Operation *parentOp) {} +void PartitionSet::dump(Operation *parentOp) {} +#endif // !NDEBUG + +LogicalResult Partition::verify(Location loc) { + // Ensure values are defined either by other ops in the partition or are + // declared as inputs. 
+ SetVector defValues; + for (auto *op : ops) { + for (auto result : op->getResults()) { + defValues.insert(result); + } + } + for (auto *op : ops) { + for (auto operand : op->getOperands()) { + if (!ins.contains(operand) && !defValues.contains(operand)) { + return mlir::emitError(loc) + << "operand not declared in partition inputs or by an op within " + "the partition"; + } + } + } + + // Ensure all outputs come from ops in the partition (or are pass-through + // operands, though those are silly). + for (auto out : outs) { + if (!ins.contains(out) && !defValues.contains(out)) { + return mlir::emitError(loc) << "output not defined by an op within the " + "partition (or captured)"; + } + } + + return success(); +} + +LogicalResult PartitionSet::verify(Location loc) { + // Verify each partition is consistent. + for (auto &partition : partitions) { + if (failed(partition.verify(loc))) return failure(); + } + + // Ensure no partitions duplicate escaping values as we need a single def to + // remap the value in the parent block. + SetVector outs; + for (auto &partition : partitions) { + for (auto out : partition.outs) { + if (outs.contains(out)) { + return mlir::emitError(loc) + << "duplicate value found in partition set outputs"; + } + outs.insert(out); + } + } + + // Ensure a correct topological order of partitions. This only checks the + // order of the partitions and not any ops that aren't covered. We do this + // by walking backwards and checking that no partition captures values + // escaping any partitions after it. 
+ SetVector declaredBelow; + for (auto &partition : llvm::reverse(partitions)) { + for (auto in : partition.ins) { + if (declaredBelow.contains(in)) { + return mlir::emitError(loc) << "partition set out of order; value " + "captured declared as escaping below"; + } + } + for (auto out : partition.outs) { + declaredBelow.insert(out); + } + } + + return success(); +} + +void PartitionSet::topologicalSort() { + if (partitions.empty()) return; + + SetVector unsortedSet; + DenseMap> consumers; + for (auto &partition : partitions) { + unsortedSet.insert(&partition); + for (auto in : partition.ins) { + consumers[in].push_back(&partition); + } + } + + struct DFSState { + SmallVector topologicalCounts; + DenseSet seen; + } state; + std::function postorderWalk; + postorderWalk = [&](Partition *current) { + for (auto out : current->outs) { + for (auto *consumer : consumers[out]) { + postorderWalk(consumer); + } + } + auto it = state.seen.insert(current); + if (/*inserted=*/it.second) { + if (unsortedSet.contains(current)) { + state.topologicalCounts.push_back(current); + } + } + }; + for (auto *partition : unsortedSet) postorderWalk(partition); + + SmallVector sortedSet; + sortedSet.reserve(partitions.size()); + for (auto *partition : llvm::reverse(state.topologicalCounts)) { + sortedSet.push_back(std::move(*partition)); + } + partitions = std::move(sortedSet); +} + +PartitionSet partitionStreamableOps(IREE::Stream::PartitioningConfigAttr config, + Block *block) { + // Only one algorithm today. + return partitionStreamableOpsReference(config, block); +} + +PartitionSet partitionRegionConcurrency( + IREE::Stream::PartitioningConfigAttr config, Block *block) { + // Only one algorithm today. 
+ return partitionRegionConcurrencyReference(config, block); +} + +} // namespace Stream +} // namespace IREE +} // namespace iree_compiler +} // namespace mlir diff --git a/iree/compiler/Dialect/Stream/Analysis/Partitioning.h b/iree/compiler/Dialect/Stream/Analysis/Partitioning.h new file mode 100644 index 000000000000..49f7608a4f09 --- /dev/null +++ b/iree/compiler/Dialect/Stream/Analysis/Partitioning.h @@ -0,0 +1,131 @@ +// Copyright 2021 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#ifndef IREE_COMPILER_DIALECT_STREAM_ANALYSIS_PARTITIONING_H_ +#define IREE_COMPILER_DIALECT_STREAM_ANALYSIS_PARTITIONING_H_ + +#include "iree/compiler/Dialect/Stream/IR/StreamTypes.h" +#include "mlir/IR/Operation.h" +#include "mlir/Support/LLVM.h" + +namespace mlir { +namespace iree_compiler { +namespace IREE { +namespace Stream { + +//===----------------------------------------------------------------------===// +// Data structures +//===----------------------------------------------------------------------===// + +// A single slice of ops. +struct Partition { + // SSA values defined outside of the partition. + // All values not defined by ops in the partition must be declared. + // Multiple partitions may capture the same value. + SetVector ins; + // SSA values defined by the partition with uses outside. + // All values used by ops outside of the partition must be declared. + // Only one partition may produce a new value. + SetVector outs; + // All ops covered by the partition. May contain ops that exist in other + // partitions in cases where the op is to be duplicated. Not all ops are + // streamable (such as constants and arithmetic). + SetVector ops; + + void dump(Operation *parentOp); + + // Verifies that the partition meets the required conditions. 
+ LogicalResult verify(Location loc); +}; + +// A set of all partitions. +struct PartitionSet { + // All partitions in an undefined topological order. + SmallVector partitions; + + // Total number of partitions in the set. + size_t size() const { return partitions.size(); } + // Returns true if the set is empty (no streamable ops). + bool empty() const { return partitions.empty(); } + + void dump(Operation *parentOp); + + // Verifies that the partition set meets the required conditions. + LogicalResult verify(Location loc); + + // Sorts all partitions in a topological order. + void topologicalSort(); +}; + +//===----------------------------------------------------------------------===// +// Stream partitioning algorithms +//===----------------------------------------------------------------------===// +// +// When these algorithms run all streamable operations have had an affinity +// assigned and are lowered out of tensor form. Some resources may have +// lifetimes associated but most will remain unassigned (`!stream.resource<*>`) +// until after partitioning. Note that there may already exist partitioned ops +// in stream.execute regions. +// +// The intent is that we can use the information we have about each operation, +// the resources moving between them, and where they should execute to better +// partition the DAG. This could optimize for reducing memory transfer between +// devices, reducing latency by minimizing cuts, maximizing concurrency by +// separating non-interfering subgraphs, etc. +// +// This is a well-researched area and there are many algorithms to choose from. +// We'll mostly want to focus on ones that are able to handle multiple criteria +// (like memory consumption, compute utilization, available capacity, etc). 
+// +// See for example: +// dagP: https://github.com/GT-TDAlab/dagP +// Multilevel Algorithms for Acyclic Partitioning of Directed Acyclic Graphs +// https://hal.inria.fr/hal-02306566/document +// METIS: https://github.com/KarypisLab/METIS +// A Fast and High Quality Multilevel Scheme for Partitioning Irregular +// Graphs +// http://glaros.dtc.umn.edu/gkhome/metis/metis/publications +// SCOTCH: https://www.labri.fr/perso/pelegrin/scotch/ +// Contributions to Parallel Multilevel Graph Partitioning +// https://www.labri.fr/perso/pelegrin/papers/hdr.pdf +// Zoltan: https://cs.sandia.gov/Zoltan/ +// https://cs.sandia.gov/Zoltan/Zoltan_pubs.html +// https://cs.sandia.gov/Zoltan/papers/zoltan_tutorial_dagstuhl09.pdf +// +// And some good papers/overviews: +// - Edge Partitioning of Large Graphs +// https://tel.archives-ouvertes.fr/tel-01956979/document +// + +// Partitions the ops in |block| such that all streamable ops are in one or more +// partitions (with >1 implying duplication). Partitions may contain +// non-streamable ops if it is safe to do so (such as std arithmetic). Not all +// ops in the block will be covered by a partition. +PartitionSet partitionStreamableOps(IREE::Stream::PartitioningConfigAttr config, + Block *block); +PartitionSet partitionRegionConcurrency( + IREE::Stream::PartitioningConfigAttr config, Block *block); + +//===----------------------------------------------------------------------===// +// Reference partitioning +//===----------------------------------------------------------------------===// + +// Naive clustering based solely on correctness with no cost model or weighting. +// Produces the largest possible streams for any given block. Unsatisfactory. +PartitionSet partitionStreamableOpsReference( + IREE::Stream::PartitioningConfigAttr config, Block *block); + +// Similarly poor algorithm to partitionStreamableOpsReference but for use +// within partitioned streams to produce waves of concurrently executable work. 
+PartitionSet partitionRegionConcurrencyReference( + IREE::Stream::PartitioningConfigAttr config, Block *block); + +} // namespace Stream +} // namespace IREE +} // namespace iree_compiler +} // namespace mlir + +#endif // IREE_COMPILER_DIALECT_STREAM_ANALYSIS_PARTITIONING_H_ diff --git a/iree/compiler/Dialect/Stream/Analysis/Partitioning/ReferencePartitioning.cpp b/iree/compiler/Dialect/Stream/Analysis/Partitioning/ReferencePartitioning.cpp new file mode 100644 index 000000000000..1d157b5d65bf --- /dev/null +++ b/iree/compiler/Dialect/Stream/Analysis/Partitioning/ReferencePartitioning.cpp @@ -0,0 +1,350 @@ +// Copyright 2021 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include "iree/compiler/Dialect/Stream/Analysis/Partitioning.h" +#include "llvm/ADT/BitVector.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallPtrSet.h" +#include "llvm/Support/Debug.h" +#include "mlir/IR/PatternMatch.h" + +#define DEBUG_TYPE "iree-stream-partitioning" + +namespace mlir { +namespace iree_compiler { +namespace IREE { +namespace Stream { + +// This is terrible. See Stream/Analysis/Partitioning.h for a description of +// what a real implementation would do. We want cost modeling for tie breakers +// when an op could be in multiple partitions, cloning for ops that are not +// worth spanning partitions (like splats), etc. +PartitionSet partitionStreamableOpsReference( + IREE::Stream::PartitioningConfigAttr config, Block *block) { + PartitionSet partitionSet; + + struct PartitionBuilder { + unsigned ordinal; + // Affinity of the partition. + IREE::Stream::AffinityAttr affinity; + // Ops present in the partition; ops may be present in multiple partitions. + SetVector ops; + }; + SmallVector> builders; + + struct OpInfo { + // Which partitions the op is contained within. 
+ llvm::BitVector membership; + // Which partitions transitively depend on this operation. + llvm::BitVector hazards; + }; + DenseMap opInfos; + + for (auto &op : llvm::reverse(*block)) { + // Skip constants; they just add noise (and since they are heavily CSE'd + // they have lots of users to test). + if (op.hasTrait()) { + LLVM_DEBUG(llvm::dbgs() << "(ignoring constant)\n"); + continue; + } + + // Initialize op info for this op - whether streamable or not. We track + // transitive hazards on each op. Note that thanks to the ordering of ops + // in SSA form (_reversed here!_) we know that once we visit this op no + // partition created after it can ever depend on it if it doesn't here. This + // lets us keep the bitvectors small. + auto &opInfo = opInfos[&op]; + opInfo.hazards.reserve(builders.size() + 1); + opInfo.hazards.resize(builders.size(), /*t=*/false); + + IREE::Stream::AffinityAttr affinityAttr; + if (auto affinityOp = dyn_cast(op)) { + affinityAttr = affinityOp.getAffinity(); + } + + LLVM_DEBUG({ + llvm::dbgs() << "====\nPartitioning op:\n"; + op.dump(); + }); + + // Set bits for each partition this op may be able to be placed into. + // We prune the set based on whether the users are part of a transitive + // dependency chain down the use-def chain to a partition. 
+ llvm::BitVector consumers(builders.size(), /*t=*/false); + for (auto user : op.getUsers()) { + auto &userInfo = opInfos[user]; + LLVM_DEBUG({ + llvm::dbgs() << "Testing user:\n"; + user->dump(); + for (auto membershipOrdinal : userInfo.membership.set_bits()) { + llvm::dbgs() << " member of partition " << membershipOrdinal << "\n"; + } + for (auto hazardOrdinal : userInfo.hazards.set_bits()) { + llvm::dbgs() << " hazard w/ partition " << hazardOrdinal << "\n"; + } + }); + consumers |= userInfo.membership; + opInfo.hazards |= userInfo.membership; + opInfo.hazards |= userInfo.hazards; + } + llvm::BitVector candidates(builders.size(), /*t=*/true); + candidates ^= opInfo.hazards; + candidates |= consumers; + + // Prune candidates that do not have a compatible affinity. + for (auto ordinal : candidates.set_bits()) { + if (!IREE::Stream::AffinityAttr::areCompatible( + affinityAttr, builders[ordinal]->affinity)) { + LLVM_DEBUG(llvm::dbgs() + << "Candidate partition " << ordinal << " incompatible\n"); + candidates.reset(ordinal); + } + } + + // If this op is not streamable then bail here; we've still setup the hazard + // map for following iteration. + auto streamableOp = dyn_cast(op); + if (!streamableOp) { + LLVM_DEBUG(llvm::dbgs() << "Not streamable (skip)\n"); + continue; + } + + // First see which partitions are consuming this that we can also safely + // move in to. + consumers &= candidates; + + opInfo.membership.reserve(builders.size() + 1); + opInfo.membership.resize(builders.size(), /*t=*/false); + + // If we have one or more consumers we should go into those first. + if (consumers.any()) { + // If we are a clonable op (like splat) clone us into every partition. + // Otherwise we just pick the first we find (probably a bad heuristic). 
+ bool shouldClone = streamableOp.preferCloneToConsumers(); + for (auto consumerOrdinal : consumers.set_bits()) { + LLVM_DEBUG(llvm::dbgs() << "Cloning into consumer partition " + << consumerOrdinal << "\n"); + builders[consumerOrdinal]->ops.insert(&op); + opInfo.membership.set(consumerOrdinal); + opInfo.hazards.reset(consumerOrdinal); + if (!shouldClone) break; + } + LLVM_DEBUG(llvm::dbgs() << "Handled streamable (continue)\n"); + continue; + } + + // No consumers - if there's any candidate then we'll go into that. + int firstCandidateOrdinal = candidates.find_first(); + if (firstCandidateOrdinal != -1) { + LLVM_DEBUG(llvm::dbgs() << "Moving to first candidate partition " + << firstCandidateOrdinal << " (continue)\n"); + builders[firstCandidateOrdinal]->ops.insert(&op); + opInfo.membership.set(firstCandidateOrdinal); + opInfo.hazards.reset(firstCandidateOrdinal); + continue; + } + + // Mark the op as having hazards against all other partitions. + if (!builders.empty()) { + opInfo.hazards.set(0, builders.size() - 1); + } + + // Create a new partition just for this op. + opInfo.membership.resize(opInfo.membership.size() + 1, /*t=*/true); + auto builder = std::make_unique(); + builder->ordinal = builders.size(); + builder->affinity = affinityAttr; + builder->ops.insert(&op); + LLVM_DEBUG(llvm::dbgs() + << "Created partition " << builder->ordinal << "\n"); + builders.push_back(std::move(builder)); + } + + // Emit partitions in forward order (as they are topologically sorted in + // reverse order from our bottom-up walk). + for (auto &builder : llvm::reverse(builders)) { + Partition partition; + + SetVector consumedValues; + SetVector producedValues; + SetVector escapingValues; + for (auto *op : llvm::reverse(builder->ops)) { + for (auto operand : op->getOperands()) { + consumedValues.insert(operand); + } + for (auto result : op->getResults()) { + producedValues.insert(result); + // TODO(benvanik): optimize this - creates n^2/nlogn behavior. 
+ for (auto user : result.getUsers()) { + if (!builder->ops.contains(user)) { + escapingValues.insert(result); + } + } + } + } + consumedValues.set_subtract(producedValues); + partition.ins = consumedValues; + partition.outs = escapingValues; + + partition.ops = std::move(builder->ops); + partitionSet.partitions.push_back(std::move(partition)); + } + + LLVM_DEBUG(partitionSet.dump(block->getParentOp())); + + return partitionSet; +} + +// This looks to extract a single level of concurrency; we should be recursively +// dividing the block to identify both serial and concurrent regions. +PartitionSet partitionRegionConcurrencyReference( + IREE::Stream::PartitioningConfigAttr config, Block *block) { + PartitionSet waveSet; + + auto favor = config ? config.getFavor().getValue() + : IREE::Stream::Favor::MinPeakMemory; + if (favor == IREE::Stream::Favor::Debug) { + // Disable partitioning when favoring debuggability. + return waveSet; + } + + struct PartitionBuilder { + unsigned ordinal; + // Ops present in the wave; ops may be present in multiple waves. + SetVector ops; + }; + SmallVector> builders; + + struct OpInfo { + // Which waves the op is contained within. + llvm::BitVector membership; + // Which waves transitively depend on this operation. + llvm::BitVector hazards; + }; + DenseMap opInfos; + + for (auto &op : llvm::reverse(*block)) { + // Skip constants; they just add noise (and since they are heavily CSE'd + // they have lots of users to test). + if (op.hasTrait()) { + LLVM_DEBUG(llvm::dbgs() << "(ignoring constant)\n"); + continue; + } + + // Initialize op info for this op - whether streamable or not. We track + // transitive hazards on each op. Note that thanks to the ordering of ops + // in SSA form (_reversed here!_) we know that once we visit this op no + // wave created after it can ever depend on it if it doesn't here. This + // lets us keep the bitvectors small. 
+ auto &opInfo = opInfos[&op]; + opInfo.hazards.reserve(builders.size() + 1); + opInfo.hazards.resize(builders.size(), /*t=*/false); + + LLVM_DEBUG({ + llvm::dbgs() << "====\nPartitioning op:\n"; + op.dump(); + }); + + // Set bits for each wave this op may be able to be placed into. + // We prune the set based on whether the users are part of a transitive + // dependency chain down the use-def chain to a wave. + llvm::BitVector consumers(builders.size(), /*t=*/false); + for (auto user : op.getUsers()) { + auto &userInfo = opInfos[user]; + LLVM_DEBUG({ + llvm::dbgs() << "Testing user:\n"; + user->dump(); + for (auto membershipOrdinal : userInfo.membership.set_bits()) { + llvm::dbgs() << " member of wave " << membershipOrdinal << "\n"; + } + int lastHazardOrdinal = userInfo.hazards.find_last(); + if (lastHazardOrdinal != -1) { + llvm::dbgs() << " hazard w/ waves 0-" << lastHazardOrdinal << "\n"; + } + }); + consumers |= userInfo.membership; + opInfo.hazards |= userInfo.membership; + opInfo.hazards |= userInfo.hazards; + } + llvm::BitVector candidates(builders.size(), /*t=*/true); + candidates ^= opInfo.hazards; + + // If this op is not streamable then bail here; we've still setup the hazard + // map for following iteration. + auto streamableOp = dyn_cast(op); + if (!streamableOp || streamableOp.isMetadata()) { + LLVM_DEBUG(llvm::dbgs() << "Not streamable/is subview (skip)\n"); + continue; + } + + opInfo.membership.reserve(builders.size() + 1); + opInfo.membership.resize(builders.size(), /*t=*/false); + + // No consumers - if there's any candidate then we'll go into that. + int firstCandidateOrdinal = favor == IREE::Stream::Favor::MinPeakMemory + ? 
candidates.find_first() + : candidates.find_last(); + if (firstCandidateOrdinal != -1) { + LLVM_DEBUG(llvm::dbgs() << "Moving to last candidate wave " + << firstCandidateOrdinal << " (continue)\n"); + builders[firstCandidateOrdinal]->ops.insert(&op); + opInfo.membership.set(firstCandidateOrdinal); + opInfo.hazards.set(0, firstCandidateOrdinal); + opInfo.hazards.reset(firstCandidateOrdinal); + continue; + } + + // Mark the op as having hazards against all other waves. + opInfo.hazards.set(0, builders.size()); + + // Create a new wave just for this op. + opInfo.membership.resize(opInfo.membership.size() + 1, /*t=*/true); + auto builder = std::make_unique(); + builder->ordinal = builders.size(); + builder->ops.insert(&op); + LLVM_DEBUG(llvm::dbgs() << "Created wave " << builder->ordinal << "\n"); + builders.push_back(std::move(builder)); + } + + // Emit waves in forward order (as they are topologically sorted in + // reverse order from our bottom-up walk). + for (auto &builder : llvm::reverse(builders)) { + Partition wave; + + SetVector consumedValues; + SetVector producedValues; + SetVector escapingValues; + for (auto *op : llvm::reverse(builder->ops)) { + for (auto operand : op->getOperands()) { + consumedValues.insert(operand); + } + for (auto result : op->getResults()) { + producedValues.insert(result); + // TODO(benvanik): optimize this - creates n^2/nlogn behavior. 
+ for (auto user : result.getUsers()) { + if (!builder->ops.contains(user)) { + escapingValues.insert(result); + } + } + } + } + consumedValues.set_subtract(producedValues); + wave.ins = consumedValues; + wave.outs = escapingValues; + + wave.ops = std::move(builder->ops); + waveSet.partitions.push_back(std::move(wave)); + } + + LLVM_DEBUG(waveSet.dump(block->getParentOp())); + + return waveSet; +} + +} // namespace Stream +} // namespace IREE +} // namespace iree_compiler +} // namespace mlir diff --git a/iree/compiler/Dialect/Stream/Transforms/BUILD b/iree/compiler/Dialect/Stream/Transforms/BUILD index c3d8fe1be6c7..754a920116c8 100644 --- a/iree/compiler/Dialect/Stream/Transforms/BUILD +++ b/iree/compiler/Dialect/Stream/Transforms/BUILD @@ -23,6 +23,8 @@ cc_library( "PassDetail.h", "Passes.cpp", "RefineUsage.cpp", + "ScheduleConcurrency.cpp", + "ScheduleExecution.cpp", "VerifyLowerings.cpp", ], hdrs = [ diff --git a/iree/compiler/Dialect/Stream/Transforms/CMakeLists.txt b/iree/compiler/Dialect/Stream/Transforms/CMakeLists.txt index 91f94702f981..32ab4a8ae833 100644 --- a/iree/compiler/Dialect/Stream/Transforms/CMakeLists.txt +++ b/iree/compiler/Dialect/Stream/Transforms/CMakeLists.txt @@ -25,6 +25,8 @@ iree_cc_library( "PassDetail.h" "Passes.cpp" "RefineUsage.cpp" + "ScheduleConcurrency.cpp" + "ScheduleExecution.cpp" "VerifyLowerings.cpp" DEPS ::PassesIncGen diff --git a/iree/compiler/Dialect/Stream/Transforms/Passes.cpp b/iree/compiler/Dialect/Stream/Transforms/Passes.cpp index 49da4e68b540..aa3ac013eead 100644 --- a/iree/compiler/Dialect/Stream/Transforms/Passes.cpp +++ b/iree/compiler/Dialect/Stream/Transforms/Passes.cpp @@ -131,6 +131,22 @@ void buildStreamAsyncPassPipeline(OpPassManager &passManager, // move across devices. We do it before scheduling waves as lifetime doesn't // change and it makes the IR cleaner. 
passManager.addPass(IREE::Stream::createRefineUsagePass()); + + //---------------------------------------------------------------------------- + // Stream formation and scheduling + //---------------------------------------------------------------------------- + + // Combine async work into execution regions. + passManager.addNestedPass( + IREE::Stream::createScheduleExecutionPass()); + passManager.addNestedPass( + IREE::Stream::createScheduleExecutionPass()); + + // Group concurrently executable work into waves. + passManager.addNestedPass( + IREE::Stream::createScheduleConcurrencyPass()); + passManager.addNestedPass( + IREE::Stream::createScheduleConcurrencyPass()); } //===----------------------------------------------------------------------===// diff --git a/iree/compiler/Dialect/Stream/Transforms/Passes.h b/iree/compiler/Dialect/Stream/Transforms/Passes.h index 8f6981899385..7d3ba8d97e0c 100644 --- a/iree/compiler/Dialect/Stream/Transforms/Passes.h +++ b/iree/compiler/Dialect/Stream/Transforms/Passes.h @@ -88,6 +88,13 @@ std::unique_ptr> createMaterializeCopyOnWritePass(); std::unique_ptr> createElideAsyncCopiesPass(); std::unique_ptr> createRefineUsagePass(); +//===----------------------------------------------------------------------===// +// Stream formation and scheduling +//===----------------------------------------------------------------------===// + +std::unique_ptr> createScheduleExecutionPass(); +std::unique_ptr> createScheduleConcurrencyPass(); + //===----------------------------------------------------------------------===// // Diagnostics //===----------------------------------------------------------------------===// diff --git a/iree/compiler/Dialect/Stream/Transforms/Passes.td b/iree/compiler/Dialect/Stream/Transforms/Passes.td index a9c9701100ec..f60c7e40577d 100644 --- a/iree/compiler/Dialect/Stream/Transforms/Passes.td +++ b/iree/compiler/Dialect/Stream/Transforms/Passes.td @@ -69,6 +69,26 @@ def RefineUsage : }]; } 
+//===----------------------------------------------------------------------===// +// Stream formation and scheduling +//===----------------------------------------------------------------------===// + +def ScheduleExecution : + Pass<"iree-stream-schedule-execution", ""> { + let summary = "Identifies and groups asynchronous operations into executable regions within function-like regions."; + let constructor = [{ + mlir::iree_compiler::IREE::Stream::createScheduleExecutionPass() + }]; +} + +def ScheduleConcurrency : + Pass<"iree-stream-schedule-concurrency", ""> { + let summary = "Identifies and groups asynchronous operations within executable regions that can run concurrently and groups them into streams."; + let constructor = [{ + mlir::iree_compiler::IREE::Stream::createScheduleConcurrencyPass() + }]; +} + //===----------------------------------------------------------------------===// // Diagnostics //===----------------------------------------------------------------------===// diff --git a/iree/compiler/Dialect/Stream/Transforms/ScheduleConcurrency.cpp b/iree/compiler/Dialect/Stream/Transforms/ScheduleConcurrency.cpp new file mode 100644 index 000000000000..577ed25193cd --- /dev/null +++ b/iree/compiler/Dialect/Stream/Transforms/ScheduleConcurrency.cpp @@ -0,0 +1,281 @@ +// Copyright 2021 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include "iree/compiler/Dialect/Stream/Analysis/Partitioning.h" +#include "iree/compiler/Dialect/Stream/IR/StreamDialect.h" +#include "iree/compiler/Dialect/Stream/IR/StreamOps.h" +#include "iree/compiler/Dialect/Stream/IR/StreamTypes.h" +#include "iree/compiler/Dialect/Stream/Transforms/PassDetail.h" +#include "iree/compiler/Dialect/Stream/Transforms/Passes.h" +#include "iree/compiler/Dialect/Util/IR/UtilDialect.h" +#include "iree/compiler/Dialect/Util/IR/UtilOps.h" +#include "iree/compiler/Dialect/Util/IR/UtilTypes.h" +#include "llvm/ADT/BitVector.h" +#include "llvm/Support/Debug.h" +#include "mlir/Dialect/StandardOps/IR/Ops.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/BlockAndValueMapping.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/Matchers.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Pass/Pass.h" + +#define DEBUG_TYPE "iree-stream-schedule-concurrency" + +namespace mlir { +namespace iree_compiler { +namespace IREE { +namespace Stream { +namespace { + +// TODO(benvanik): deduplicate this with ScheduleExecution - almost all of this +// is identical. + +// Incremental builder for a partitioned region of executable work. +// Must be constructed in a topological order of all partitions. +struct WavePartitionBuilder { + explicit WavePartitionBuilder(Block *parentBlock, size_t ordinal, + Partition *partition, + BlockAndValueMapping &parentMapping, + MLIRContext *context) + : ordinal(ordinal), partition(partition), builder(context) { + // Fuse the location of all ops we'll be putting in the partition. + SmallVector locs; + for (auto *op : partition->ops) { + locs.push_back(op->getLoc()); + } + auto fusedLoc = FusedLoc::get(context, locs); + + // Find the insertion point in the parent block. + // This is at the last op defining an input as all inputs must be available. 
+ Operation *insertionPt = nullptr; + for (auto in : partition->ins) { + auto *definingOp = in.getDefiningOp(); + if (!definingOp) continue; + if (definingOp->getBlock() != parentBlock) continue; + if (!insertionPt) { + insertionPt = definingOp; // first defining op + } else if (insertionPt->isBeforeInBlock(definingOp)) { + insertionPt = definingOp; // moving insertion point down + } + } + OpBuilder parentBuilder(context); + if (insertionPt) { + parentBuilder.setInsertionPointAfter(insertionPt); + } else { + parentBuilder.setInsertionPointToStart(parentBlock); + } + + // Gather operands and result types from the declared partition I/O. + // These are values from the original block. Note that because we are + // constructing in order we know that any results of prior partitions are + // in the |parentMapping|. + SmallVector resultTypes; + SmallVector resultSizes; + resultTypes.reserve(partition->outs.size()); + resultSizes.reserve(partition->outs.size()); + for (auto out : partition->outs) { + resultTypes.push_back(out.getType()); + auto resultSize = IREE::Util::SizeAwareTypeInterface::queryValueSize( + fusedLoc, out, parentBuilder); + if (resultSize) resultSizes.push_back(resultSize); + } + SmallVector operands; + SmallVector operandTypes; + SmallVector operandSizes; + operands.reserve(partition->ins.size()); + operandTypes.reserve(partition->ins.size()); + operandSizes.reserve(partition->ins.size()); + for (auto in : partition->ins) { + if (!in.getType().isa()) continue; + operands.push_back(in); + operandTypes.push_back(in.getType()); + auto operandSize = IREE::Util::SizeAwareTypeInterface::queryValueSize( + fusedLoc, in, parentBuilder); + if (operandSize) operandSizes.push_back(operandSize); + } + + // TODO(benvanik): tie operands, or leave to canonicalization. + SmallVector tiedOperands; + concurrentOp = parentBuilder.create( + fusedLoc, resultTypes, resultSizes, operands, operandSizes, + tiedOperands); + + // Add entry block and arguments. 
+ auto &entryBlock = concurrentOp.body().emplaceBlock(); + for (auto args : + llvm::zip(operands, entryBlock.addArguments(operandTypes))) { + mapping.map(std::get<0>(args), std::get<1>(args)); + } + builder = OpBuilder::atBlockBegin(&entryBlock); + + // Remap results for escaping outputs. + for (auto results : llvm::zip(partition->outs, concurrentOp.results())) { + parentMapping.map(std::get<0>(results), std::get<1>(results)); + } + } + + // Visits a block operation and clones it into the partition, if desired. + // + // Slightly suboptimal to be calling this on each op for each partition, + // however we only walk the block once and constructing a multimap would be + // way worse. + // + // Returns true if the operation was cloned into the partition. + bool visit(Operation *op) { + if (!partition->ops.contains(op)) return false; + + // Clone the op into the partition and remap it. + auto *clonedOp = builder.clone(*op, mapping); + (void)clonedOp; + LLVM_DEBUG({ + llvm::dbgs() << "Cloned op into partition " << ordinal << ": "; + clonedOp->dump(); + }); + + return true; + } + + void finish() { + // Gather results mapped into the SSA values we've cloned. 
+ SmallVector results; + SmallVector resultSizes; + results.reserve(partition->outs.size()); + resultSizes.reserve(partition->outs.size()); + for (auto oldResult : partition->outs) { + auto newResult = mapping.lookup(oldResult); + results.push_back(newResult); + auto resultSize = IREE::Util::SizeAwareTypeInterface::queryValueSize( + concurrentOp.getLoc(), newResult, builder); + if (resultSize) resultSizes.push_back(resultSize); + } + builder.create(concurrentOp.getLoc(), results, + resultSizes); + } + + size_t ordinal = -1; + Partition *partition = nullptr; + IREE::Stream::AsyncConcurrentOp concurrentOp; + OpBuilder builder; + BlockAndValueMapping mapping; +}; + +class ScheduleConcurrencyPass + : public ScheduleConcurrencyBase { + public: + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); + registry.insert(); + } + + void runOnOperation() override { + auto parentOp = dyn_cast(getOperation()); + if (!parentOp || !parentOp.getCallableRegion() || + parentOp.getCallableRegion()->empty()) { + return; + } + for (auto executeOp : + parentOp.getCallableRegion()->getOps()) { + if (failed(runOnRegion(executeOp))) return signalPassFailure(); + } + } + + LogicalResult runOnRegion(IREE::Stream::AsyncExecuteOp parentOp) { + if (parentOp.body().empty()) { + return success(); + } + auto *block = &parentOp.body().front(); + + // Lookup the optional config used to control partitioning. + auto configAttr = IREE::Stream::PartitioningConfigAttr::lookup(parentOp); + + // Compute a set of partitions covering all of the streamable ops in the + // execution region. + auto waveSet = partitionRegionConcurrency(configAttr, block); + if (waveSet.empty()) return success(); + if (failed(waveSet.verify(parentOp.getLoc()))) return failure(); + + // Create partition builders for each partition. + // We'll clone ops into each and insert them into the block at the + // appropriate position (first use... probably). 
+ BlockAndValueMapping mapping; + SmallVector partitionBuilders; + partitionBuilders.reserve(waveSet.size()); + for (auto partition : llvm::enumerate(waveSet.partitions)) { + if (partition.value().ops.size() == 1) continue; + partitionBuilders.push_back(WavePartitionBuilder(block, partition.index(), + &partition.value(), + mapping, &getContext())); + } + + // Walk over each op in the original block and find those that need to be + // partitioned. Each partition builder may clone the op into itself. The + // op will always be left in the original block and we'll rely on DCE to + // remove the ones no longer required. This is not a good approach as it + // creates a lot of new IR (up to O(op*partitions)). + SetVector deadOps; + for (auto &op : *block) { + if (op.hasTrait()) continue; + bool handled = false; + for (auto &partitionBuilder : partitionBuilders) { + handled = partitionBuilder.visit(&op) || handled; + } + if (handled) { + deadOps.insert(&op); + } + } + + // Apply remapping for values captured/escaping partitions. + // We must do this per block as we'll be updating dominated block values. + for (auto &partitionBuilder : partitionBuilders) { + for (auto resultPair : + llvm::zip(partitionBuilder.partition->outs, + partitionBuilder.concurrentOp.results())) { + auto oldResult = std::get<0>(resultPair); + auto newResult = std::get<1>(resultPair); + oldResult.replaceAllUsesWith(newResult); + deadOps.insert(oldResult.getDefiningOp()); + } + partitionBuilder.finish(); + + // Extremely shady reordering of ops we know (should) be safe to move + // after the partition - otherwise, we shouldn't have moved the source + // ops into the partition. 
+ auto concurrentOp = partitionBuilder.concurrentOp; + for (auto user : concurrentOp->getUsers()) { + if (user->getBlock() == concurrentOp->getBlock() && + user->isBeforeInBlock(partitionBuilder.concurrentOp)) { + LLVM_DEBUG({ + llvm::dbgs() << "Shady move of op to after partition: "; + user->dump(); + }); + user->moveAfter(concurrentOp); + } + } + } + for (auto *deadOp : llvm::reverse(deadOps)) { + deadOp->erase(); + } + + LLVM_DEBUG({ + llvm::dbgs() << "\nWaves constructed:\n"; + block->dump(); + }); + return success(); + } +}; + +} // namespace + +std::unique_ptr> createScheduleConcurrencyPass() { + return std::make_unique(); +} + +} // namespace Stream +} // namespace IREE +} // namespace iree_compiler +} // namespace mlir diff --git a/iree/compiler/Dialect/Stream/Transforms/ScheduleExecution.cpp b/iree/compiler/Dialect/Stream/Transforms/ScheduleExecution.cpp new file mode 100644 index 000000000000..8c5c4e5e74a0 --- /dev/null +++ b/iree/compiler/Dialect/Stream/Transforms/ScheduleExecution.cpp @@ -0,0 +1,347 @@ +// Copyright 2021 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include "iree/compiler/Dialect/Stream/Analysis/Partitioning.h" +#include "iree/compiler/Dialect/Stream/IR/StreamDialect.h" +#include "iree/compiler/Dialect/Stream/IR/StreamOps.h" +#include "iree/compiler/Dialect/Stream/IR/StreamTypes.h" +#include "iree/compiler/Dialect/Stream/Transforms/PassDetail.h" +#include "iree/compiler/Dialect/Stream/Transforms/Passes.h" +#include "iree/compiler/Dialect/Util/IR/UtilDialect.h" +#include "iree/compiler/Dialect/Util/IR/UtilOps.h" +#include "iree/compiler/Dialect/Util/IR/UtilTypes.h" +#include "llvm/ADT/EquivalenceClasses.h" +#include "llvm/Support/Debug.h" +#include "mlir/Dialect/StandardOps/IR/Ops.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/BlockAndValueMapping.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/Dominance.h" +#include "mlir/IR/Matchers.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassRegistry.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +#define DEBUG_TYPE "iree-stream-schedule-execution" + +namespace mlir { +namespace iree_compiler { +namespace IREE { +namespace Stream { +namespace { + +// Incremental builder for a partitioned region of executable work. +// Must be constructed in a topological order of all partitions. +struct ExecutePartitionBuilder { + explicit ExecutePartitionBuilder(Block *parentBlock, size_t ordinal, + Partition *partition, + BlockAndValueMapping &parentMapping, + MLIRContext *context) + : ordinal(ordinal), partition(partition), builder(context) { + // Fuse the location of all ops we'll be putting in the partition. + SmallVector locs; + for (auto *op : partition->ops) { + locs.push_back(op->getLoc()); + } + auto fusedLoc = FusedLoc::get(context, locs); + + // Find the insertion point in the parent block. + // This is at the last op defining an input as all inputs must be available. 
+ Operation *insertionPt = nullptr; + for (auto in : partition->ins) { + auto *definingOp = in.getDefiningOp(); + if (!definingOp) continue; + if (definingOp->getBlock() != parentBlock) continue; + if (!insertionPt) { + insertionPt = definingOp; // first defining op + } else if (insertionPt->isBeforeInBlock(definingOp)) { + insertionPt = definingOp; // moving insertion point down + } + } + OpBuilder parentBuilder(context); + if (insertionPt) { + parentBuilder.setInsertionPointAfter(insertionPt); + } else { + parentBuilder.setInsertionPointToStart(parentBlock); + } + + // Gather operands and result types from the declared partition I/O. + // These are values from the original block. Note that because we are + // constructing in order we know that any results of prior partitions are + // in the |parentMapping|. + SmallVector resultTypes; + SmallVector resultSizes; + resultTypes.reserve(partition->outs.size()); + resultSizes.reserve(partition->outs.size()); + for (auto out : partition->outs) { + resultTypes.push_back(out.getType()); + auto resultSize = IREE::Util::SizeAwareTypeInterface::queryValueSize( + fusedLoc, out, parentBuilder); + if (resultSize) resultSizes.push_back(resultSize); + } + SmallVector operands; + SmallVector operandTypes; + SmallVector operandSizes; + operands.reserve(partition->ins.size()); + operandTypes.reserve(partition->ins.size()); + operandSizes.reserve(partition->ins.size()); + for (auto in : partition->ins) { + if (!in.getType().isa()) continue; + operands.push_back(in); + operandTypes.push_back(in.getType()); + auto operandSize = IREE::Util::SizeAwareTypeInterface::queryValueSize( + fusedLoc, in, parentBuilder); + if (operandSize) operandSizes.push_back(operandSize); + } + + // TODO(benvanik): tie operands, or leave to canonicalization. 
+ SmallVector tiedOperands; + executeOp = parentBuilder.create( + fusedLoc, resultTypes, resultSizes, /*awaitTimepoint=*/Value{}, + operands, operandSizes, tiedOperands); + + // Add entry block and arguments. + auto &entryBlock = executeOp.body().emplaceBlock(); + for (auto args : + llvm::zip(operands, entryBlock.addArguments(operandTypes))) { + mapping.map(std::get<0>(args), std::get<1>(args)); + } + builder = OpBuilder::atBlockBegin(&entryBlock); + + // Remap results for escaping outputs. + for (auto results : llvm::zip(partition->outs, executeOp.results())) { + parentMapping.map(std::get<0>(results), std::get<1>(results)); + } + } + + // Visits a block operation and clones it into the partition, if desired. + // + // Slightly suboptimal to be calling this on each op for each partition, + // however we only walk the block once and constructing a multimap would be + // way worse. + // + // Returns true if the operation was cloned into the partition. + bool visit(Operation *op) { + if (!partition->ops.contains(op)) return false; + + // Clone the op into the partition and remap it. + auto *clonedOp = builder.clone(*op, mapping); + (void)clonedOp; + LLVM_DEBUG({ + llvm::dbgs() << "Cloned op into partition " << ordinal << ": "; + clonedOp->dump(); + }); + + return true; + } + + IREE::Stream::AsyncExecuteOp finish() { + // Gather results mapped into the SSA values we've cloned. 
+ SmallVector results; + SmallVector resultSizes; + results.reserve(partition->outs.size()); + resultSizes.reserve(partition->outs.size()); + for (auto oldResult : partition->outs) { + auto newResult = mapping.lookup(oldResult); + results.push_back(newResult); + auto resultSize = IREE::Util::SizeAwareTypeInterface::queryValueSize( + executeOp.getLoc(), newResult, builder); + if (resultSize) resultSizes.push_back(resultSize); + } + builder.create(executeOp.getLoc(), results, + resultSizes); + return executeOp; + } + + size_t ordinal = -1; + Partition *partition = nullptr; + IREE::Stream::AsyncExecuteOp executeOp; + OpBuilder builder; + BlockAndValueMapping mapping; +}; + +// Sorts blocks in dominance order such that the entry block is first and +// all of the following blocks are dominated only by blocks that have come +// before them in the list. +static SmallVector sortBlocksInDominanceOrder(Region ®ion) { + if (region.getBlocks().size() == 1) { + // Dominance info cannot be computed for regions with one block. 
+ return {®ion.getBlocks().front()}; + } + + DominanceInfo dominanceInfo(region.getParentOp()); + llvm::SmallSetVector unmarkedBlocks; + for (auto &block : region.getBlocks()) { + unmarkedBlocks.insert(&block); + } + llvm::SmallSetVector markedBlocks; + std::function visit = [&](Block *block) { + if (markedBlocks.count(block) > 0) return; + for (auto *childBlock : dominanceInfo.getNode(block)->children()) { + visit(childBlock->getBlock()); + } + markedBlocks.insert(block); + }; + while (!unmarkedBlocks.empty()) { + visit(unmarkedBlocks.pop_back_val()); + } + auto orderedBlocks = markedBlocks.takeVector(); + std::reverse(orderedBlocks.begin(), orderedBlocks.end()); + return orderedBlocks; +} + +class ScheduleExecutionPass + : public ScheduleExecutionBase { + public: + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); + registry.insert(); + } + + void runOnOperation() override { + auto *context = &getContext(); + auto parentOp = dyn_cast(getOperation()); + if (!parentOp || !parentOp.getCallableRegion() || + parentOp.getCallableRegion()->empty()) { + return; + } + + // Lookup the optional config used to control partitioning. + auto configAttr = IREE::Stream::PartitioningConfigAttr::lookup(parentOp); + + // Partition each block on its own. We could try to partition with the CFG + // however that's much more complex - it's easier to handle partitioning + // structured control flow (scf) ops. Note that we do this in dominance + // order so that we are sure if we replace values that dominate other blocks + // they see the correct values. + auto ®ion = *parentOp.getCallableRegion(); + for (auto *block : sortBlocksInDominanceOrder(region)) { + // Compute a set of partitions covering all of the streamable ops in the + // block. 
+ auto partitionSet = partitionStreamableOps(configAttr, block); + if (partitionSet.empty()) continue; + if (failed(partitionSet.verify(parentOp.getLoc()))) { + return signalPassFailure(); + } + + // Create partition builders for each partition. + // We'll clone ops into each and insert them into the block at the + // appropriate position (first use... probably). + BlockAndValueMapping mapping; + SmallVector partitionBuilders; + partitionBuilders.reserve(partitionSet.size()); + for (auto partition : llvm::enumerate(partitionSet.partitions)) { + partitionBuilders.push_back(ExecutePartitionBuilder( + block, partition.index(), &partition.value(), mapping, context)); + } + + // Walk over each op in the original block and find those that need to be + // partitioned. Each partition builder may clone the op into itself. The + // op will always be left in the original block and we'll rely on DCE to + // remove the ones no longer required. This is not a good approach as it + // creates a lot of new IR (up to O(op*partitions)). + SetVector deadOps; + for (auto &op : *block) { + if (op.hasTrait()) continue; + for (auto &partitionBuilder : partitionBuilders) { + partitionBuilder.visit(&op); + } + if (isa(op)) { + deadOps.insert(&op); + } + } + + // Apply remapping for values captured/escaping partitions. + // We must do this per block as we'll be updating dominated block values. + for (auto &partitionBuilder : partitionBuilders) { + // Finish construction and insert the yield. + auto executeOp = partitionBuilder.finish(); + + OpBuilder builder(executeOp); + builder.setInsertionPointAfter(executeOp); + for (auto it : + llvm::zip(partitionBuilder.partition->outs, executeOp.results(), + executeOp.result_sizes())) { + auto oldResult = std::get<0>(it); + auto newResult = std::get<1>(it); + auto newResultSize = std::get<2>(it); + + // Insert one await per result. We could batch them all but that would + // prematurely tie their lifetimes together. 
By having unique awaits + // we allow propagation to move the waits further to where the values + // are used (including right into other execution regions). + auto awaitOp = builder.create( + executeOp.getLoc(), newResult, newResultSize, + executeOp.result_timepoint()); + if (executeOp.affinity().hasValue()) { + awaitOp.affinityAttr(executeOp.affinityAttr()); + } + + oldResult.replaceAllUsesWith(awaitOp.results().front()); + deadOps.insert(oldResult.getDefiningOp()); + } + + // Extremely shady reordering of ops we know (should) be safe to move + // after the partition - otherwise, we shouldn't have moved the source + // ops into the partition. + SetVector worklist; + for (auto user : executeOp->getUsers()) { + worklist.insert(user); + } + while (!worklist.empty()) { + auto *user = worklist.pop_back_val(); + if (user->getBlock() == executeOp->getBlock() && + user->isBeforeInBlock(executeOp)) { + LLVM_DEBUG({ + llvm::dbgs() << "Shady move of op to after partition: "; + user->dump(); + }); + user->moveAfter(builder.getInsertionBlock(), + builder.getInsertionPoint()); + } + for (auto subUser : user->getUsers()) { + worklist.insert(subUser); + } + } + } + for (auto *deadOp : llvm::reverse(deadOps)) { + deadOp->erase(); + } + + LLVM_DEBUG({ + llvm::dbgs() << "\nPartitions constructed:\n"; + block->dump(); + }); + } + + // Cleanup the dead ops. + // TODO(benvanik): less work here - maybe no patterns to just force folding? 
+ OwningRewritePatternList patterns(context); + for (auto *dialect : context->getLoadedDialects()) { + dialect->getCanonicalizationPatterns(patterns); + } + for (auto *op : context->getRegisteredOperations()) { + op->getCanonicalizationPatterns(patterns, context); + } + FrozenRewritePatternSet frozenPatterns(std::move(patterns)); + if (failed(applyPatternsAndFoldGreedily(getOperation(), frozenPatterns))) { + return signalPassFailure(); + } + } +}; + +} // namespace + +std::unique_ptr> createScheduleExecutionPass() { + return std::make_unique(); +} + +} // namespace Stream +} // namespace IREE +} // namespace iree_compiler +} // namespace mlir diff --git a/iree/compiler/Dialect/Stream/Transforms/test/BUILD b/iree/compiler/Dialect/Stream/Transforms/test/BUILD index 16765b22e3d7..c9a1c8172e15 100644 --- a/iree/compiler/Dialect/Stream/Transforms/test/BUILD +++ b/iree/compiler/Dialect/Stream/Transforms/test/BUILD @@ -23,6 +23,8 @@ iree_lit_test_suite( "materialize_copy_on_write.mlir", "outline_constants.mlir", "refine_usage.mlir", + "schedule_concurrency.mlir", + "schedule_execution.mlir", ], include = ["*.mlir"], ), diff --git a/iree/compiler/Dialect/Stream/Transforms/test/CMakeLists.txt b/iree/compiler/Dialect/Stream/Transforms/test/CMakeLists.txt index 3317a4da47fa..f83d356f790c 100644 --- a/iree/compiler/Dialect/Stream/Transforms/test/CMakeLists.txt +++ b/iree/compiler/Dialect/Stream/Transforms/test/CMakeLists.txt @@ -20,6 +20,8 @@ iree_lit_test_suite( "materialize_copy_on_write.mlir" "outline_constants.mlir" "refine_usage.mlir" + "schedule_concurrency.mlir" + "schedule_execution.mlir" DATA iree::tools::IreeFileCheck iree::tools::iree-opt diff --git a/iree/compiler/Dialect/Stream/Transforms/test/schedule_concurrency.mlir b/iree/compiler/Dialect/Stream/Transforms/test/schedule_concurrency.mlir new file mode 100644 index 000000000000..2cd287ae8c13 --- /dev/null +++ b/iree/compiler/Dialect/Stream/Transforms/test/schedule_concurrency.mlir @@ -0,0 +1,47 @@ +// RUN: 
iree-opt -split-input-file -pass-pipeline="builtin.func(iree-stream-schedule-concurrency)" %s | IreeFileCheck %s + +// CHECK-LABEL: @partitioning +// CHECK-SAME: (%[[ARG0:.+]]: !stream.resource, %[[ARG1:.+]]: !stream.resource) +func @partitioning(%arg0: !stream.resource, %arg1: !stream.resource) -> !stream.resource { + %c1 = arith.constant 1 : index + %c20 = arith.constant 20 : index + %c80 = arith.constant 80 : index + %c1280 = arith.constant 1280 : index + %cst = arith.constant 0x7F800000 : f32 + // CHECK: stream.async.execute + %results, %result_timepoint = stream.async.execute + // CHECK-SAME: with(%[[ARG1]] as %[[ARG1_CAPTURE:.+]]: !stream.resource{%c80}, + // CHECK-SAME: %[[ARG0]] as %[[ARG0_CAPTURE:.+]]: !stream.resource{%c20}) + with(%arg1 as %arg2: !stream.resource{%c80}, + %arg0 as %arg3: !stream.resource{%c20}) + -> !stream.resource{%c20} { + + // CHECK: %[[CON0:.+]]:2 = stream.async.concurrent with() + // CHECK-SAME: -> (!stream.resource{%c1280}, !stream.resource{%c20}) { + // CHECK-NEXT: %[[SPLAT0:.+]] = stream.async.splat %cst : f32 -> !stream.resource{%c1280} + // CHECK-NEXT: %[[SPLAT1:.+]] = stream.async.splat %cst : f32 -> !stream.resource{%c20} + // CHECK-NEXT: stream.yield %[[SPLAT0]], %[[SPLAT1]] : !stream.resource{%c1280}, !stream.resource{%c20} + + // CHECK: %[[CON1:.+]]:2 = stream.async.concurrent + // CHECK-SAME: with(%[[CON0]]#0 as %[[CON0_0_CAPTURE:.+]]: !stream.resource{%c1280}, + // CHECK-SAME: %[[ARG1_CAPTURE]] as %[[ARG1_CON1_CAPTURE:.+]]: !stream.resource{%c80}, + // CHECK-SAME: %[[ARG0_CAPTURE]] as %[[ARG0_CON1_CAPTURE:.+]]: !stream.resource{%c20}, + // CHECK-SAME: %[[CON0]]#1 as %[[CON0_1_CAPTURE:.+]]: !stream.resource{%c20}) + // CHECK-SAME: -> (!stream.resource{%c1280}, !stream.resource{%c20}) { + // CHECK-NEXT: %[[DISPATCH0:.+]] = stream.async.dispatch @ex::@dispatch_0[%c1, %c1, %c1](%[[CON0_0_CAPTURE]], %[[ARG1_CON1_CAPTURE]]) + // CHECK-NEXT: %[[DISPATCH1:.+]] = stream.async.dispatch @ex::@dispatch_1[%c1, %c1, 
%c1](%[[ARG0_CON1_CAPTURE]], %[[CON0_1_CAPTURE]]) + // CHECK-NEXT: stream.yield %[[DISPATCH0]], %[[DISPATCH1]] + + // CHECK: %[[DISPATCH2:.+]] = stream.async.dispatch @ex::@dispatch_2[%c1, %c1, %c1](%[[CON1]]#0, %[[CON1]]#1) + // CHECK-NEXT: stream.yield %[[DISPATCH2]] + + %1 = stream.async.splat %cst : f32 -> !stream.resource{%c1280} + %2 = stream.async.dispatch @ex::@dispatch_0[%c1, %c1, %c1](%1, %arg2) : (!stream.resource{%c1280}, !stream.resource{%c80}) -> %1{%c1280} + %3 = stream.async.splat %cst : f32 -> !stream.resource{%c20} + %4 = stream.async.dispatch @ex::@dispatch_1[%c1, %c1, %c1](%arg3, %3) : (!stream.resource{%c20}, !stream.resource{%c20}) -> %3{%c20} + %5 = stream.async.dispatch @ex::@dispatch_2[%c1, %c1, %c1](%2, %4) : (!stream.resource{%c1280}, !stream.resource{%c20}) -> !stream.resource{%c20} + stream.yield %5 : !stream.resource{%c20} + } => !stream.timepoint + %0 = stream.timepoint.await %result_timepoint => %results : !stream.resource{%c20} + return %0 : !stream.resource +} diff --git a/iree/compiler/Dialect/Stream/Transforms/test/schedule_execution.mlir b/iree/compiler/Dialect/Stream/Transforms/test/schedule_execution.mlir new file mode 100644 index 000000000000..4f1b70b26711 --- /dev/null +++ b/iree/compiler/Dialect/Stream/Transforms/test/schedule_execution.mlir @@ -0,0 +1,104 @@ +// RUN: iree-opt -split-input-file -pass-pipeline="builtin.func(iree-stream-schedule-execution)" %s | IreeFileCheck %s + +// Tests basic partitioning of multiple ops. 
+ +// CHECK-LABEL: @partitioning +// CHECK-SAME: (%[[ARG0:.+]]: !stream.resource, %[[ARG1:.+]]: !stream.resource) +func @partitioning(%arg0: !stream.resource, %arg1: !stream.resource) -> !stream.resource { + %c1 = arith.constant 1 : index + %c20 = arith.constant 20 : index + %c80 = arith.constant 80 : index + %c1280 = arith.constant 1280 : index + %cst = arith.constant 0x7F800000 : f32 + // CHECK: %[[RESULT:.+]], %[[TIMEPOINT:.+]] = stream.async.execute + // CHECK-SAME: with(%[[ARG1]] as %[[ARG1_CAPTURE:.+]]: !stream.resource{%c80}, + // CHECK-SAME: %[[ARG0]] as %[[ARG0_CAPTURE:.+]]: !stream.resource{%c20}) + // CHECK-SAME: -> !stream.resource{%c20} { + // CHECK-NEXT: %[[SPLAT0:.+]] = stream.async.splat + %2 = stream.async.splat %cst : f32 -> !stream.resource{%c1280} + // CHECK-NEXT: %[[DISPATCH0:.+]] = stream.async.dispatch @ex::@dispatch_0[%c1, %c1, %c1](%[[SPLAT0]], %[[ARG1_CAPTURE]]) : (!stream.resource{%c1280}, !stream.resource{%c80}) -> %[[SPLAT0]]{%c1280} + %3 = stream.async.dispatch @ex::@dispatch_0[%c1, %c1, %c1](%2, %arg1) : (!stream.resource{%c1280}, !stream.resource{%c80}) -> %2{%c1280} + // CHECK-NEXT: %[[SPLAT1:.+]] = stream.async.splat + %4 = stream.async.splat %cst : f32 -> !stream.resource{%c20} + // CHECK-NEXT: %[[DISPATCH1:.+]] = stream.async.dispatch @ex::@dispatch_1[%c1, %c1, %c1](%[[ARG0_CAPTURE]], %[[SPLAT1]]) : (!stream.resource{%c20}, !stream.resource{%c20}) -> %[[SPLAT1]]{%c20} + %5 = stream.async.dispatch @ex::@dispatch_1[%c1, %c1, %c1](%arg0, %4) : (!stream.resource{%c20}, !stream.resource{%c20}) -> %4{%c20} + // CHECK-NEXT: %[[DISPATCH2:.+]] = stream.async.dispatch @ex::@dispatch_2[%c1, %c1, %c1](%[[DISPATCH0]], %[[DISPATCH1]]) : (!stream.resource{%c1280}, !stream.resource{%c20}) -> !stream.resource{%c20} + %6 = stream.async.dispatch @ex::@dispatch_2[%c1, %c1, %c1](%3, %5) : (!stream.resource{%c1280}, !stream.resource{%c20}) -> !stream.resource{%c20} + // CHECK-NEXT: stream.yield %[[DISPATCH2]] : !stream.resource{%c20} + // CHECK-NEXT: 
} => !stream.timepoint + // CHECK-NEXT: %[[READY:.+]] = stream.timepoint.await %[[TIMEPOINT]] => %[[RESULT]] : !stream.resource{%c20} + // CHECK-NEXT: return %[[READY]] + return %6 : !stream.resource +} + +// ----- + +// Tests that ops in multiple blocks are partitioned independently and that +// timepoints are chained between the partitions. Note that the dispatches +// happen in-place on the splat and we expect the execution regions to be tied. + +// CHECK-LABEL: @partitionWithinBlocks +func @partitionWithinBlocks(%cond: i1) -> !stream.resource { + %c1 = arith.constant 1 : index + %c1280 = arith.constant 1280 : index + %cst = arith.constant 0x7F800000 : f32 + // CHECK: %[[SPLAT:.+]], %[[SPLAT_TIMEPOINT:.+]] = stream.async.execute + // CHECK: stream.async.splat + %splat = stream.async.splat %cst : f32 -> !stream.resource{%c1280} + // CHECK: cond_br + cond_br %cond, ^bb1, ^bb2 +^bb1: + // CHECK: %[[BB1_RESULT:.+]], %[[BB1_TIMEPOINT:.+]] = stream.async.execute await(%[[SPLAT_TIMEPOINT]]) => + // CHECK-SAME: with(%[[SPLAT]] as %[[BB1_SPLAT:.+]]: !stream.resource{%c1280}) + // CHECK-SAME: -> %[[SPLAT]]{%c1280} + // CHECK: stream.async.dispatch @ex::@dispatch_0[%c1, %c1, %c1](%[[BB1_SPLAT]]) : (!stream.resource{%c1280}) -> %[[BB1_SPLAT]]{%c1280} + %3 = stream.async.dispatch @ex::@dispatch_0[%c1, %c1, %c1](%splat) : (!stream.resource{%c1280}) -> %splat{%c1280} + // CHECK: %[[BB1_READY:.+]] = stream.timepoint.await %[[BB1_TIMEPOINT]] => %[[BB1_RESULT]] + // CHECK: return %[[BB1_READY]] + return %3 : !stream.resource +^bb2: + // CHECK: %[[BB2_RESULT:.+]], %[[BB2_TIMEPOINT:.+]] = stream.async.execute await(%[[SPLAT_TIMEPOINT]]) => + // CHECK-SAME: with(%[[SPLAT]] as %[[BB2_SPLAT:.+]]: !stream.resource{%c1280}) + // CHECK-SAME: -> %[[SPLAT]]{%c1280} + // CHECK: stream.async.dispatch @ex::@dispatch_1[%c1, %c1, %c1](%[[BB2_SPLAT]]) : (!stream.resource{%c1280}) -> %[[BB2_SPLAT]]{%c1280} + %4 = stream.async.dispatch @ex::@dispatch_1[%c1, %c1, %c1](%splat) : 
(!stream.resource{%c1280}) -> %splat{%c1280} + // CHECK: %[[BB2_READY:.+]] = stream.timepoint.await %[[BB2_TIMEPOINT]] => %[[BB2_RESULT]] + // CHECK: return %[[BB2_READY]] + return %4 : !stream.resource +} + +// ----- + +// Tests a complex device->host->device sequence gets turned into the proper +// execute->await->execute. These data-dependent operations can happen in a +// single block and break the assumption that one block == one partition. + +// CHECK-LABEL: @deviceHostDevice +func @deviceHostDevice() -> !stream.resource { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c123_i8 = arith.constant 123 : i8 + // CHECK: %[[RESULT_D2H:.+]], %[[TIMEPOINT_D2H:.+]] = stream.async.execute with() + // CHECK-SAME: -> !stream.resource{%c1} + // CHECK-NEXT: %[[SPLAT:.+]] = stream.async.splat %c123_i8 + %0 = stream.async.splat %c123_i8 : i8 -> !stream.resource{%c1} + // CHECK-NEXT: %[[TRANSFER_D2H:.+]] = stream.async.transfer %[[SPLAT]] + %1 = stream.async.transfer %0 : !stream.resource{%c1} -> !stream.resource{%c1} + // CHECK-NEXT: stream.yield %[[TRANSFER_D2H]] + // CHECK: %[[READY_D2H:.+]] = stream.timepoint.await %[[TIMEPOINT_D2H]] => %[[RESULT_D2H]] : !stream.resource{%c1} + // CHECK: %[[LOAD:.+]] = stream.async.load %[[READY_D2H]] + %2 = stream.async.load %1[%c0] : !stream.resource{%c1} -> i8 + // CHECK: %[[ADD:.+]] = arith.addi %[[LOAD]], %[[LOAD]] + %3 = arith.addi %2, %2 : i8 + // CHECK: %[[STORE:.+]] = stream.async.store %[[ADD]], %[[READY_D2H]] + %4 = stream.async.store %3, %1[%c0] : i8 -> !stream.resource{%c1} + // CHECK: %[[RESULT_H2D:.+]], %[[TIMEPOINT_H2D:.+]] = stream.async.execute + // CHECK-SAME: with(%[[STORE]] as %[[STORE_CAPTURE:.+]]: !stream.resource{%c1}) + // CHECK-SAME: -> !stream.resource{%c1} + // CHECK-NEXT: %[[TRANSFER_H2D:.+]] = stream.async.transfer %[[STORE_CAPTURE]] + %5 = stream.async.transfer %4 : !stream.resource{%c1} -> !stream.resource{%c1} + // CHECK-NEXT: stream.yield %[[TRANSFER_H2D]] + // CHECK: 
%[[READY_H2D:.+]] = stream.timepoint.await %[[TIMEPOINT_H2D]] => %[[RESULT_H2D]] : !stream.resource{%c1} + // CHECK: return %[[READY_H2D]] + return %5 : !stream.resource +}