[mlir][ArmSME] Support filling liveness 'holes' in the tile allocator (llvm#98350)

Holes in a live range are points where the corresponding value does not
need to be in a tile/register. If the tile allocator keeps track of
these holes, it can reuse tiles for more values (avoiding spills).

Take this simple example:

```mlir
func.func @example(%cond: i1) {
  %tileA = arm_sme.get_tile : vector<[4]x[4]xf32>
  cf.cond_br %cond, ^bb2, ^bb1
^bb1:
  // If we end up here, we never use %tileA again!
  %tileB = arm_sme.get_tile : vector<[4]x[4]xf32>
  "test.some_use"(%tileB) : (vector<[4]x[4]xf32>) -> ()
  cf.br ^bb3
^bb2:
  "test.some_use"(%tileA) : (vector<[4]x[4]xf32>) -> ()
  cf.br ^bb3
^bb3:
  return
}
```

If you were to calculate the liveness of %tileA and %tileB, you'd see there
is a hole in the liveness of %tileA in ^bb1:

```
      %tileA  %tileB
^bb0:  Live
^bb1:          Live
^bb2:  Live
```

The tile allocator can make use of that hole and reuse the tile ID it
assigned to %tileA for %tileB.
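
For illustration, here is the same example annotated with the `tile_id` attributes this change allows the allocator to produce. The concrete ID value (0) is only illustrative; what matters (and what the new `@fill_holes_in_tile_liveness` test below checks) is that %tileA and %tileB receive the *same* tile ID:

```mlir
func.func @example(%cond: i1) {
  // %tileA and %tileB can share a tile ID thanks to the hole in %tileA's range.
  %tileA = arm_sme.get_tile {tile_id = 0 : i32} : vector<[4]x[4]xf32>
  cf.cond_br %cond, ^bb2, ^bb1
^bb1:
  %tileB = arm_sme.get_tile {tile_id = 0 : i32} : vector<[4]x[4]xf32>
  "test.some_use"(%tileB) : (vector<[4]x[4]xf32>) -> ()
  cf.br ^bb3
^bb2:
  "test.some_use"(%tileA) : (vector<[4]x[4]xf32>) -> ()
  cf.br ^bb3
^bb3:
  return
}
```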
MacDue authored and sgundapa committed Jul 23, 2024
1 parent 91e327b commit 80da4b2
Showing 2 changed files with 283 additions and 29 deletions.
134 changes: 105 additions & 29 deletions mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp
@@ -153,10 +153,18 @@ class TileAllocator {
     return failure();
   }
 
+  /// Acquires a specific tile ID. Asserts the tile is initially free.
+  void acquireTileId(ArmSMETileType tileType, unsigned tileId) {
+    TileMask tileMask = getMasks(tileType)[tileId];
+    assert((tilesInUse & tileMask) == TileMask::kNone &&
+           "cannot acquire allocated tile!");
+    tilesInUse |= tileMask;
+  }
+
   /// Releases a previously allocated tile ID.
   void releaseTileId(ArmSMETileType tileType, unsigned tileId) {
     TileMask tileMask = getMasks(tileType)[tileId];
-    assert((tilesInUse & tileMask) != TileMask::kNone &&
+    assert((tilesInUse & tileMask) == tileMask &&
            "cannot release unallocated tile!");
     tilesInUse ^= tileMask;
   }
@@ -289,6 +297,11 @@ struct LiveRange {
         .valid();
   }
 
+  /// Returns true if this range is active at `point` in the program.
+  bool overlaps(uint64_t point) const {
+    return ranges->lookup(point) == kValidLiveRange;
+  }
+
   /// Unions this live range with `otherRange`, aborts if the ranges overlap.
   void unionWith(LiveRange const &otherRange) {
     for (auto it = otherRange.ranges->begin(); it != otherRange.ranges->end();
@@ -488,76 +501,139 @@ coalesceTileLiveRanges(DenseMap<Value, LiveRange> &initialLiveRanges) {
   return std::move(coalescedLiveRanges);
 }
 
-/// Choose a live range to spill (via some heuristics). This picks either an
-/// active live range from `activeRanges` or the new live range `newRange`.
-LiveRange *chooseSpillUsingHeuristics(ArrayRef<LiveRange *> activeRanges,
-                                      LiveRange *newRange) {
+/// Choose a live range to spill (via some heuristics). This picks either a live
+/// range from `overlappingRanges`, or the new live range `newRange`.
+template <typename OverlappingRangesIterator>
+LiveRange *
+chooseSpillUsingHeuristics(OverlappingRangesIterator overlappingRanges,
+                           LiveRange *newRange) {
   // Heuristic: Spill trivially copyable operations (usually free).
-  auto isTrivialSpill = [&](LiveRange *allocatedRange) {
-    return isTileTypeGreaterOrEqual(allocatedRange->getTileType(),
+  auto isTrivialSpill = [&](LiveRange &allocatedRange) {
+    return isTileTypeGreaterOrEqual(allocatedRange.getTileType(),
                                     newRange->getTileType()) &&
-           allocatedRange->values.size() == 1 &&
+           allocatedRange.values.size() == 1 &&
            isTriviallyCloneableTileOp(
-               allocatedRange->values[0]
-                   .getDefiningOp<ArmSMETileOpInterface>());
+               allocatedRange.values[0].getDefiningOp<ArmSMETileOpInterface>());
   };
-  if (isTrivialSpill(newRange))
+  if (isTrivialSpill(*newRange))
     return newRange;
-  auto trivialSpill = llvm::find_if(activeRanges, isTrivialSpill);
-  if (trivialSpill != activeRanges.end())
-    return *trivialSpill;
+  auto trivialSpill = llvm::find_if(overlappingRanges, isTrivialSpill);
+  if (trivialSpill != overlappingRanges.end())
+    return &*trivialSpill;
 
   // Heuristic: Spill the range that ends last (with a compatible tile type).
-  auto isSmallerTileTypeOrEndsEarlier = [](LiveRange *a, LiveRange *b) {
-    return !isTileTypeGreaterOrEqual(a->getTileType(), b->getTileType()) ||
-           a->end() < b->end();
+  auto isSmallerTileTypeOrEndsEarlier = [](LiveRange &a, LiveRange &b) {
+    return !isTileTypeGreaterOrEqual(a.getTileType(), b.getTileType()) ||
+           a.end() < b.end();
   };
-  LiveRange *lastActiveLiveRange = *std::max_element(
-      activeRanges.begin(), activeRanges.end(), isSmallerTileTypeOrEndsEarlier);
-  if (!isSmallerTileTypeOrEndsEarlier(lastActiveLiveRange, newRange))
-    return lastActiveLiveRange;
+  LiveRange &latestEndingLiveRange =
+      *std::max_element(overlappingRanges.begin(), overlappingRanges.end(),
+                        isSmallerTileTypeOrEndsEarlier);
+  if (!isSmallerTileTypeOrEndsEarlier(latestEndingLiveRange, *newRange))
+    return &latestEndingLiveRange;
   return newRange;
 }
 
 /// Greedily allocate tile IDs to live ranges. Spill using simple heuristics.
-/// Note: This does not attempt to fill holes in active live ranges.
 void allocateTilesToLiveRanges(
     ArrayRef<LiveRange *> liveRangesSortedByStartPoint) {
   TileAllocator tileAllocator;
+  // `activeRanges` = Live ranges that need to be in a tile at the
+  // `currentPoint` in the program.
   SetVector<LiveRange *> activeRanges;
+  // `inactiveRanges` = Live ranges that _do not_ need to be in a tile
+  // at the `currentPoint` in the program but could become active again later.
+  // An inactive section of a live range can be seen as a 'hole' in the live
+  // range, where it is possible to reuse the live range's tile ID _before_ it
+  // has ended. By identifying 'holes', the allocator can reuse tiles more
+  // often, which helps avoid costly tile spills.
+  SetVector<LiveRange *> inactiveRanges;
   for (LiveRange *nextRange : liveRangesSortedByStartPoint) {
-    // Release tile IDs from live ranges that have ended.
+    auto currentPoint = nextRange->start();
+    // 1. Update the `activeRanges` at `currentPoint`.
     activeRanges.remove_if([&](LiveRange *activeRange) {
-      if (activeRange->end() <= nextRange->start()) {
+      // Check for live ranges that have expired.
+      if (activeRange->end() <= currentPoint) {
        tileAllocator.releaseTileId(activeRange->getTileType(),
                                     *activeRange->tileId);
         return true;
       }
+      // Check for live ranges that have become inactive.
+      if (!activeRange->overlaps(currentPoint)) {
+        tileAllocator.releaseTileId(activeRange->getTileType(),
+                                    *activeRange->tileId);
+        inactiveRanges.insert(activeRange);
+        return true;
+      }
       return false;
     });
+    // 2. Update the `inactiveRanges` at `currentPoint`.
+    inactiveRanges.remove_if([&](LiveRange *inactiveRange) {
+      // Check for live ranges that have expired.
+      if (inactiveRange->end() <= currentPoint) {
+        return true;
+      }
+      // Check for live ranges that have become active.
+      if (inactiveRange->overlaps(currentPoint)) {
+        tileAllocator.acquireTileId(inactiveRange->getTileType(),
+                                    *inactiveRange->tileId);
+        activeRanges.insert(inactiveRange);
+        return true;
+      }
+      return false;
+    });
+
+    // 3. Collect inactive live ranges that overlap with the new live range.
+    // Note: The overlap checks in steps 1 and 2 only look at the `currentPoint`
+    // whereas this checks if there is an overlap at any future point too.
+    SmallVector<LiveRange *> overlappingInactiveRanges;
+    for (LiveRange *inactiveRange : inactiveRanges) {
+      if (inactiveRange->overlaps(*nextRange)) {
+        // We need to reserve the tile IDs of overlapping inactive ranges to
+        // prevent two (overlapping) live ranges from getting the same tile ID.
+        tileAllocator.acquireTileId(inactiveRange->getTileType(),
+                                    *inactiveRange->tileId);
+        overlappingInactiveRanges.push_back(inactiveRange);
+      }
+    }
 
-    // Allocate a tile ID to `nextRange`.
+    // 4. Allocate a tile ID to `nextRange`.
     auto rangeTileType = nextRange->getTileType();
     auto tileId = tileAllocator.allocateTileId(rangeTileType);
     if (succeeded(tileId)) {
       nextRange->tileId = *tileId;
     } else {
+      // Create an iterator over all overlapping live ranges.
+      auto allOverlappingRanges = llvm::concat<LiveRange>(
+          llvm::make_pointee_range(activeRanges.getArrayRef()),
+          llvm::make_pointee_range(overlappingInactiveRanges));
+      // Choose an overlapping live range to spill.
       LiveRange *rangeToSpill =
-          chooseSpillUsingHeuristics(activeRanges.getArrayRef(), nextRange);
+          chooseSpillUsingHeuristics(allOverlappingRanges, nextRange);
       if (rangeToSpill != nextRange) {
-        // Spill an active live range (so release its tile ID first).
+        // Spill an (in)active live range (so release its tile ID first).
         tileAllocator.releaseTileId(rangeToSpill->getTileType(),
                                     *rangeToSpill->tileId);
-        activeRanges.remove(rangeToSpill);
         // This will always succeed after a spill (of an active live range).
         nextRange->tileId = *tileAllocator.allocateTileId(rangeTileType);
+        // Remove the live range from the active/inactive sets.
+        if (!activeRanges.remove(rangeToSpill)) {
+          bool removed = inactiveRanges.remove(rangeToSpill);
+          assert(removed && "expected a range to be removed!");
+        }
       }
       rangeToSpill->tileId = tileAllocator.allocateInMemoryTileId();
    }
 
-    // Insert the live range into the active ranges.
+    // 5. Insert the live range into the active ranges.
     if (nextRange->tileId < kInMemoryTileIdBase)
       activeRanges.insert(nextRange);
+
+    // 6. Release tiles reserved for inactive live ranges (in step 3).
+    for (LiveRange *range : overlappingInactiveRanges) {
+      if (*range->tileId < kInMemoryTileIdBase)
+        tileAllocator.releaseTileId(range->getTileType(), *range->tileId);
+    }
   }
 }

178 changes: 178 additions & 0 deletions mlir/test/Dialect/ArmSME/tile-allocation-liveness.mlir
@@ -430,3 +430,181 @@ func.func @cond_branch_with_backedge(%slice: vector<[4]xf32>) {
// Live here: %finalTileA, %finalTileB, %finalTileC, %finalTileD
return
}

// -----

// CHECK-LIVE-RANGE-LABEL: @fill_holes_in_tile_liveness
// CHECK-LIVE-RANGE: ========== Coalesced Live Ranges:
// CHECK-LIVE-RANGE: ^bb0:
// CHECK-LIVE-RANGE-NEXT: S arm_sme.get_tile
// CHECK-LIVE-RANGE-NEXT: E cf.cond_br
// CHECK-LIVE-RANGE-NEXT: ^bb1:
// CHECK-LIVE-RANGE-NEXT: S arm_sme.get_tile
// CHECK-LIVE-RANGE-NEXT: | test.dummy
// CHECK-LIVE-RANGE-NEXT: E test.some_use
// CHECK-LIVE-RANGE-NEXT: cf.br
// CHECK-LIVE-RANGE-NEXT: ^bb2:
// CHECK-LIVE-RANGE-NEXT: | test.dummy
// CHECK-LIVE-RANGE-NEXT: | test.dummy
// CHECK-LIVE-RANGE-NEXT: | test.dummy
// CHECK-LIVE-RANGE-NEXT: E test.some_use
// CHECK-LIVE-RANGE-NEXT: cf.br

// Here there's a 'hole' in the liveness of %tileA (in bb1) where another value
// can reuse the tile ID assigned to %tileA. The liveness for %tileB is
// entirely within the 'hole' in %tileA's live range, so %tileB should get the
// same tile ID as %tileA.

// CHECK-LABEL: @fill_holes_in_tile_liveness
func.func @fill_holes_in_tile_liveness(%cond: i1) {
// CHECK: arm_sme.get_tile {tile_id = [[TILE_ID_A:.*]] : i32}
%tileA = arm_sme.get_tile : vector<[4]x[4]xf32>
cf.cond_br %cond, ^bb2, ^bb1
^bb1:
// CHECK: arm_sme.get_tile {tile_id = [[TILE_ID_A]] : i32}
%tileB = arm_sme.get_tile : vector<[4]x[4]xf32>
"test.dummy"(): () -> ()
"test.some_use"(%tileB) : (vector<[4]x[4]xf32>) -> ()
cf.br ^bb3
^bb2:
"test.dummy"(): () -> ()
"test.dummy"(): () -> ()
"test.dummy"(): () -> ()
"test.some_use"(%tileA) : (vector<[4]x[4]xf32>) -> ()
cf.br ^bb3
^bb3:
return
}

// -----

// CHECK-LIVE-RANGE-LABEL: @holes_in_tile_liveness_inactive_overlaps
// CHECK-LIVE-RANGE: ========== Coalesced Live Ranges:
// CHECK-LIVE-RANGE: ^bb0:
// CHECK-LIVE-RANGE-NEXT: S arm_sme.get_tile
// CHECK-LIVE-RANGE-NEXT: E cf.cond_br
// CHECK-LIVE-RANGE-NEXT: ^bb1:
// CHECK-LIVE-RANGE-NEXT: S arm_sme.get_tile
// CHECK-LIVE-RANGE-NEXT: | test.dummy
// CHECK-LIVE-RANGE-NEXT: | test.some_use
// CHECK-LIVE-RANGE-NEXT: | arm_sme.copy_tile
// CHECK-LIVE-RANGE-NEXT: E cf.br
// CHECK-LIVE-RANGE-NEXT: ^bb2:
// CHECK-LIVE-RANGE-NEXT: | test.dummy
// CHECK-LIVE-RANGE-NEXT: | test.dummy
// CHECK-LIVE-RANGE-NEXT: | test.dummy
// CHECK-LIVE-RANGE-NEXT: |S arm_sme.get_tile
// CHECK-LIVE-RANGE-NEXT: E| test.some_use
// CHECK-LIVE-RANGE-NEXT: | arm_sme.copy_tile
// CHECK-LIVE-RANGE-NEXT: E cf.br
// CHECK-LIVE-RANGE-NEXT: ^bb3:
// CHECK-LIVE-RANGE-NEXT: E test.some_use
// CHECK-LIVE-RANGE-NEXT: func.return

// This tests an edge case in inactive live ranges. The first live range is
// inactive at the start of ^bb1. If the tile allocator did not check if the
// second live range overlapped the first it would wrongly re-use tile ID 0
// (as the first live range is inactive so tile ID 0 is free). This would mean
// in ^bb2 two overlapping live ranges would have the same tile ID (bad!).

// CHECK-LABEL: @holes_in_tile_liveness_inactive_overlaps
func.func @holes_in_tile_liveness_inactive_overlaps(%cond: i1) {
// CHECK: arm_sme.get_tile {tile_id = 0 : i32}
%tileA = arm_sme.get_tile : vector<[4]x[4]xf32>
cf.cond_br %cond, ^bb2, ^bb1
^bb1:
// CHECK: arm_sme.get_tile {tile_id = 1 : i32}
%tileB = arm_sme.get_tile : vector<[4]x[4]xf32>
"test.dummy"(): () -> ()
"test.some_use"(%tileB) : (vector<[4]x[4]xf32>) -> ()
cf.br ^bb3(%tileB: vector<[4]x[4]xf32>)
^bb2:
"test.dummy"(): () -> ()
"test.dummy"(): () -> ()
"test.dummy"(): () -> ()
// CHECK: arm_sme.get_tile {tile_id = 1 : i32}
%tileC = arm_sme.get_tile : vector<[4]x[4]xf32>
"test.some_use"(%tileA) : (vector<[4]x[4]xf32>) -> ()
cf.br ^bb3(%tileC: vector<[4]x[4]xf32>)
^bb3(%tile: vector<[4]x[4]xf32>):
"test.some_use"(%tile) : (vector<[4]x[4]xf32>) -> ()
return
}

// -----

// This is the same as the previous example, but changes the tile types to
// vector<[16]x[16]xi8>. This means in bb1 the allocator will need to spill the
// first live range (which is inactive).

// Note: The live ranges are the same as the previous example (so are not checked).

// CHECK-LABEL: @spill_inactive_live_range
func.func @spill_inactive_live_range(%cond: i1) {
// CHECK: arm_sme.get_tile {tile_id = 16 : i32}
%tileA = arm_sme.get_tile : vector<[16]x[16]xi8>
cf.cond_br %cond, ^bb2, ^bb1
^bb1:
// CHECK: arm_sme.get_tile {tile_id = 0 : i32}
%tileB = arm_sme.get_tile : vector<[16]x[16]xi8>
"test.dummy"(): () -> ()
"test.some_use"(%tileB) : (vector<[16]x[16]xi8>) -> ()
cf.br ^bb3(%tileB: vector<[16]x[16]xi8>)
^bb2:
"test.dummy"(): () -> ()
"test.dummy"(): () -> ()
"test.dummy"(): () -> ()
// CHECK: arm_sme.get_tile {tile_id = 0 : i32}
%tileC = arm_sme.get_tile : vector<[16]x[16]xi8>
"test.some_use"(%tileA) : (vector<[16]x[16]xi8>) -> ()
cf.br ^bb3(%tileC: vector<[16]x[16]xi8>)
^bb3(%tile: vector<[16]x[16]xi8>):
"test.some_use"(%tile) : (vector<[16]x[16]xi8>) -> ()
return
}

// -----

// CHECK-LIVE-RANGE-LABEL: @reactivate_inactive_live_range
// CHECK-LIVE-RANGE: ========== Coalesced Live Ranges:
// CHECK-LIVE-RANGE: ^bb0:
// CHECK-LIVE-RANGE-NEXT: S arm_sme.get_tile
// CHECK-LIVE-RANGE-NEXT: E cf.cond_br
// CHECK-LIVE-RANGE-NEXT: ^bb1:
// CHECK-LIVE-RANGE-NEXT: S arm_sme.get_tile
// CHECK-LIVE-RANGE-NEXT: | test.dummy
// CHECK-LIVE-RANGE-NEXT: E test.some_use
// CHECK-LIVE-RANGE-NEXT: cf.br
// CHECK-LIVE-RANGE-NEXT: ^bb2:
// CHECK-LIVE-RANGE-NEXT: | S arm_sme.get_tile
// CHECK-LIVE-RANGE-NEXT: | | test.dummy
// CHECK-LIVE-RANGE-NEXT: | | test.dummy
// CHECK-LIVE-RANGE-NEXT: | E test.some_use
// CHECK-LIVE-RANGE-NEXT: E test.some_use
// CHECK-LIVE-RANGE-NEXT: cf.br

// Here the live range for %tileA becomes inactive in bb1 (so %tileB gets tile
// ID 0 too). Then in bb2 the live range for %tileA is reactivated as it overlaps
// with the start of %tileC's live range (which means %tileC gets tile ID 1).

func.func @reactivate_inactive_live_range(%cond: i1) {
// CHECK: arm_sme.get_tile {tile_id = 0 : i32}
%tileA = arm_sme.get_tile : vector<[4]x[4]xf32>
cf.cond_br %cond, ^bb2, ^bb1
^bb1:
// CHECK: arm_sme.get_tile {tile_id = 0 : i32}
%tileB = arm_sme.get_tile : vector<[16]x[16]xi8>
"test.dummy"(): () -> ()
"test.some_use"(%tileB) : (vector<[16]x[16]xi8>) -> ()
cf.br ^bb3
^bb2:
// CHECK: arm_sme.get_tile {tile_id = 1 : i32}
%tileC = arm_sme.get_tile : vector<[4]x[4]xf32>
"test.dummy"(): () -> ()
"test.dummy"(): () -> ()
"test.some_use"(%tileC) : (vector<[4]x[4]xf32>) -> ()
"test.some_use"(%tileA) : (vector<[4]x[4]xf32>) -> ()
cf.br ^bb3
^bb3:
return
}
