Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Transform] Hoist thread-local allocator within the nested parallel loops #283

Open
wants to merge 41 commits into
base: main
Choose a base branch
from

Conversation

ciyongch
Copy link

This PR is to track #120

This is another implementation of hoisting thread-local allocator base on the new memref-merge design, so it depends on #44

@@ -30,6 +30,69 @@ using namespace special_ticks;
/// and default memory space.
static bool isMemRefTypeOk(MemRefType type) { return type.hasStaticShape(); }

static inline int64_t getSizeInBytes(MemRefType &memType) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

shall we use sub-class overriding instead of directly changing the to-be-upstreamed code? It can prove that our tick-based interfaces is extendable. And it can decouple the downstream logic from the upstream part.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My consideration is: this is not a big enhancement in addition to the existing general allocator hoist logic within this "framework", if we're going to unify all the allocator behavior in a separate extension instead of mixing them, I think we can go with this way.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It will be hard to extract the downstream logic when the upstream PR got merged and if we would like to rebase. :) Just a suggestion. it is up to you anyway :)

isa<arith::ConstantIndexOp>(ub.getDefiningOp()));
});

isStatic &= llvm::all_of(lowerBounds, [](Value &lb) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

shall we also check the step? I am also not sure if &= will short-cut the evaluation of llvm::all_of(...) if isStatic is false.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, I can add the check for the step, and use early return when the expression return false.

%alloc_0 = memref.alloc() : memref<8xf32>
%1 = scf.for %k = %lb to %ub step %step
iter_args(%iterBuf = %arg0) -> (memref<2xf32>) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just to confirm, what are we testing here, to set iter_args to arg0? Is it to check if the scheduler can skip complex lifetime?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Neither, the case is simply to demonstrate the scenario of the mixed usage of scf.forall and scf.for, the only testing purpose is for those allocators in the case. And BTW, iter_args still supports memref in addition to tensor.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We will not generate loops with memref in iter_args after bufferization. This feature is for dynamic allocations which will be identified as complex access.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for your explanation, the complex access is not the main testing purpose for the case here. And it was originally borrowed from: https://github.com/llvm/llvm-project/blob/main/mlir/test/Dialect/Bufferization/Transforms/buffer-loop-hoisting.mlir#L163-L181.

while (parent) {
if (auto forallOp = dyn_cast<scf::ForallOp>(parent)) {
if (isForallLoopBoundStatic(forallOp)) {
SmallVector<Value> upperBounds = forallOp.getUpperBound(builder);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

suggest to use getStaticUpperBound instead.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Seems getStaticUpperBound() always return the "std::numeric_limits<int64_t>::min()", so I keep the current impl.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK. That may be we are using constant ops instead of attrs for constant bounds. Another possible way may be to check both MiexedValue and Value being constants.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This doesn't make much difference, both ways shall be fine?

llvm::zip(forallOp.getMixedLowerBound(), forallOp.getMixedUpperBound(),
forallOp.getMixedStep())) {
std::optional<int64_t> ubConst = getConstantIntValue(ub);
return ubConst.has_value() && isConstantIntValue(lb, 0) &&
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why not support the case with lb != 0 or step != 1? We may have loop like

scf.forall (%arg7) = (0) to (512) step (32)

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is not a hard limitation, it's now updated to support this case.


// Get the total number of threads from the outermost to the current level of
// the parallel loop that the allocation located in.
int64_t numThreads = 1;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The calculation of numThreads could directly use the upstream util constantTripCount.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the tips, changed to use the util function.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants