Skip to content

Commit

Permalink
[LLVMGPU] Delete dead code in prefetch pass (#18543)
Browse files Browse the repository at this point in the history
The multi-stage prefetching was not used/tested.

Also fix some typos.
  • Loading branch information
kuhar committed Sep 17, 2024
1 parent 6a44005 commit ad8f814
Showing 1 changed file with 13 additions and 100 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,6 @@ class LoopPrefetcher {
prefetcher.mapping = SmallVector<IRMapping>(4);
prefetcher.forOp = op;
prefetcher.lb = prefetcher.ub = prefetcher.step = 0;
prefetcher.singleStage = true;

if (failed(prefetcher.initializeLoopInfo())) {
LDBG("Failed to initialize loop info (unsupported loop)");
Expand All @@ -99,44 +98,20 @@ class LoopPrefetcher {

// Emits the prologue before the main pipelined loop and returns the read
// results to be passed to the main loop as initial loop carried values, and
// their useages by corresponding writes in the main loop.
// their usages by corresponding writes in the main loop.
std::tuple<SmallVector<Value>, SmallVector<Value>>
emitPrologue(RewriterBase &rewriter) {
Location loc = forOp.getLoc();
Value zero = rewriter.create<arith::ConstantIndexOp>(loc, lb);
Value one = rewriter.create<arith::ConstantIndexOp>(loc, lb + step);
SmallVector<Value> iterArgs;
SmallVector<Value> readResults;
SmallVector<Value> writeArgs;

if (singleStage) {
// If we only prefetch one step ahead, we can directly write in the
// prologue and use the shared memory to communicate data instead of the
// loop carried values.
// Read (0)
emitRead(mapping[0], rewriter, zero);
// Write(0)
emitWrite(mapping[0], rewriter, zero);
return {iterArgs, writeArgs};
}

// Read(0).
iterArgs = emitRead(mapping[0], rewriter, zero);
// Read(1).
readResults = emitRead(mapping[1], rewriter, one);
llvm::append_range(iterArgs, readResults);

// Collect the values to be used as write args.
for (Operation *op : readStage) {
if (auto transferReadOp = dyn_cast<vector::TransferReadOp>(op)) {
for (Operation *user : transferReadOp.getResult().getUsers()) {
if (auto writeOp = dyn_cast<vector::TransferWriteOp>(user)) {
writeArgs.push_back(writeOp.getVector());
}
}
}
}

// Directly write in the prologue and use the shared memory to communicate
// data instead of the loop carried values. Read (0)
emitRead(mapping[0], rewriter, zero);
// Write(0)
emitWrite(mapping[0], rewriter, zero);
return {iterArgs, writeArgs};
}

Expand All @@ -145,7 +120,7 @@ class LoopPrefetcher {
SmallVector<Value> &newIterArgs,
SmallVector<Value> &writeArgs) {
Location loc = forOp.getLoc();
int64_t newUpperBound = singleStage ? (ub - step) : (ub - 2 * step);
int64_t newUpperBound = ub - step;
auto newUb = rewriter.create<arith::ConstantIndexOp>(loc, newUpperBound);

// Keep original iter args and then add some for what's being loaded to
Expand All @@ -154,9 +129,6 @@ class LoopPrefetcher {
llvm::append_range(iterArgs, newIterArgs);

Value newStep = forOp.getStep();
if (!singleStage) {
newStep = rewriter.create<arith::AddIOp>(loc, newStep, newStep);
}
auto newForOp = rewriter.create<scf::ForOp>(loc, forOp.getLowerBound(),
newUb, newStep, iterArgs);

Expand All @@ -165,19 +137,6 @@ class LoopPrefetcher {
if (!newForOp.getBody()->empty())
rewriter.eraseOp(newForOp.getBody()->getTerminator());

if (singleStage)
return newForOp;

SmallVector<Value> targetValues(writeArgs.size());
for (size_t i = 0, e = writeArgs.size(); i != e; ++i)
targetValues[i] = newForOp.getRegionIterArg(i + 1);

createWriteMappings(writeArgs, targetValues, mapping[0]);

for (size_t i = 0, e = writeArgs.size(); i != e; ++i)
targetValues[i] = newForOp.getRegionIterArg(i + e + 1);

createWriteMappings(writeArgs, targetValues, mapping[1]);
return newForOp;
}

Expand All @@ -188,8 +147,6 @@ class LoopPrefetcher {
Value indVar = newForOp.getInductionVar();
Value increment = rewriter.create<arith::ConstantIndexOp>(loc, step);
Value iPlusOne = rewriter.create<arith::AddIOp>(loc, indVar, increment);
Value iPlusTwo = rewriter.create<arith::AddIOp>(loc, iPlusOne, increment);
Value iPlusThree = rewriter.create<arith::AddIOp>(loc, iPlusTwo, increment);

for (int i = 0; i < 3; ++i) {
for (auto [idx, arg] : llvm::enumerate(forOp.getRegionIterArgs())) {
Expand All @@ -198,29 +155,13 @@ class LoopPrefetcher {
}

SmallVector<Value> readRegisters, moreRegisters;
if (singleStage) {
emitRead(mapping[1], rewriter, iPlusOne);
emitBarrier(loc, rewriter);
emitCompute(mapping[0], rewriter, indVar);
emitBarrier(loc, rewriter);
emitWrite(mapping[1], rewriter, iPlusOne);
updateYield(mapping[0], readRegisters, rewriter);
return;
}

emitWrite(mapping[0], rewriter, indVar);
readRegisters = emitRead(mapping[2], rewriter, iPlusTwo);
emitRead(mapping[1], rewriter, iPlusOne);
emitBarrier(loc, rewriter);
auto computeResults = emitCompute(mapping[0], rewriter, indVar);
mapping[0].map(forOp.getRegionIterArg(0), computeResults[0]);
emitCompute(mapping[0], rewriter, indVar);
emitBarrier(loc, rewriter);
emitWrite(mapping[1], rewriter, iPlusOne);
moreRegisters = emitRead(mapping[3], rewriter, iPlusThree);
emitBarrier(loc, rewriter);
emitCompute(mapping[0], rewriter, iPlusOne);
emitBarrier(loc, rewriter);
readRegisters.append(moreRegisters.begin(), moreRegisters.end());
updateYield(mapping[0], readRegisters, rewriter);
return;
}

// Emits the epilogue after the main pipelined loop and returns the final
Expand All @@ -229,8 +170,6 @@ class LoopPrefetcher {
SmallVector<Value> &writeArgs) {
rewriter.setInsertionPointAfter(newForOp);
Location loc = forOp.getLoc();
Value nMinusTwo =
rewriter.create<arith::ConstantIndexOp>(loc, ub - 2 * step);
Value nMinusOne =
rewriter.create<arith::ConstantIndexOp>(loc, ub - 1 * step);

Expand All @@ -239,32 +178,8 @@ class LoopPrefetcher {
mapping[0].map(forOp.getRegionIterArg(i), newForOp.getResult(i));
}

if (singleStage) {
emitBarrier(loc, rewriter);
return emitCompute(mapping[0], rewriter, nMinusOne);
}

SmallVector<Value> targetValues(writeArgs.size());
for (size_t i = 0, e = writeArgs.size(); i != e; ++i)
targetValues[i] = newForOp.getResult(i + 1);

createWriteMappings(writeArgs, targetValues, mapping[2]);

for (size_t i = 0, e = writeArgs.size(); i != e; ++i)
targetValues[i] = newForOp.getResult(i + e + 1);

createWriteMappings(writeArgs, targetValues, mapping[3]);

emitWrite(mapping[2], rewriter, nMinusTwo);
emitBarrier(loc, rewriter);
SmallVector<Value> computeResults =
emitCompute(mapping[0], rewriter, nMinusTwo);
mapping[0].map(forOp.getRegionIterArg(0), computeResults[0]);
emitBarrier(loc, rewriter);
emitWrite(mapping[3], rewriter, nMinusOne);
emitBarrier(loc, rewriter);
computeResults = emitCompute(mapping[0], rewriter, nMinusOne);
return computeResults;
return emitCompute(mapping[0], rewriter, nMinusOne);
}

private:
Expand Down Expand Up @@ -310,7 +225,7 @@ class LoopPrefetcher {
}

// We only support loops whose bodies can be divided into 3 stages (read,
// write, compute). If there are any remaning ops with side effects (except
// write, compute). If there are any remaining ops with side effects (except
// for gpu.barrier), the loop is not supported.
LogicalResult initializeStages() {
DenseSet<Operation *> readDependencies;
Expand Down Expand Up @@ -372,7 +287,7 @@ class LoopPrefetcher {
return success();
}

/// Clones |op| and call |callback| on the cloned op's oeprands as well as any
/// Clones |op| and call |callback| on the cloned op's operands as well as any
/// operands of nested ops that 1) aren't defined within the new op or 2) are
/// block arguments.
static Operation *
Expand Down Expand Up @@ -511,8 +426,6 @@ class LoopPrefetcher {
scf::ForOp forOp;
// Original static loop range and step.
int64_t lb, ub, step;
// Whether we only prefetch one single step ahead.
bool singleStage;

// Ops in the original scf.for loop that belongs to different classes.
SmallVector<Operation *> readStage;
Expand Down

0 comments on commit ad8f814

Please sign in to comment.