[Transform] Hoist thread-local allocator within the nested parallel loops #283
base: main
…ijie/mem-merge
@@ -30,6 +30,69 @@ using namespace special_ticks;
/// and default memory space.
static bool isMemRefTypeOk(MemRefType type) { return type.hasStaticShape(); }

static inline int64_t getSizeInBytes(MemRefType &memType) {
Shall we use subclass overriding instead of directly changing the to-be-upstreamed code? That would show that our tick-based interfaces are extensible, and it would decouple the downstream logic from the upstream part.
My consideration is that this is not a big enhancement on top of the existing general allocator hoisting logic within this "framework". If we are going to unify all the allocator behavior in a separate extension instead of mixing it in, I think we can go that way.
It will be hard to extract the downstream logic once the upstream PR gets merged and we want to rebase. :) Just a suggestion; it is up to you anyway. :)
isa<arith::ConstantIndexOp>(ub.getDefiningOp()));
});

isStatic &= llvm::all_of(lowerBounds, [](Value &lb) {
Shall we also check the step? I am also not sure whether `&=` will short-circuit the evaluation of `llvm::all_of(...)` if `isStatic` is false.
Yes, I can add the check for the step and return early when an expression evaluates to false.
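The `&=` question can be checked in isolation. The following is a minimal sketch, not the PR's code: `Bound`, `Probe`, `allConstant`, and both `isStatic...` functions are hypothetical names, and `std::all_of` stands in for `llvm::all_of`. Because `a &= f()` always evaluates `f()`, the accumulate form scans every range, while the early-return form stops at the first failing range.

```cpp
#include <algorithm>
#include <cstdint>
#include <optional>
#include <vector>

// Hypothetical stand-in for a loop bound: holds a value only if it is constant.
using Bound = std::optional<int64_t>;

// Counts predicate invocations so that the evaluation order is observable.
struct Probe {
  int visited = 0;
};

static bool allConstant(const std::vector<Bound> &bounds, Probe &probe) {
  return std::all_of(bounds.begin(), bounds.end(), [&](const Bound &b) {
    ++probe.visited;
    return b.has_value();
  });
}

// Accumulate form: `&=` is not short-circuiting, so all three ranges are
// scanned even when the first one already failed.
bool isStaticAccumulate(const std::vector<Bound> &ubs,
                        const std::vector<Bound> &lbs,
                        const std::vector<Bound> &steps, Probe &probe) {
  bool isStatic = allConstant(ubs, probe);
  isStatic &= allConstant(lbs, probe);
  isStatic &= allConstant(steps, probe);
  return isStatic;
}

// Early-return form: later ranges are skipped once one range fails.
bool isStaticEarlyReturn(const std::vector<Bound> &ubs,
                         const std::vector<Bound> &lbs,
                         const std::vector<Bound> &steps, Probe &probe) {
  if (!allConstant(ubs, probe))
    return false;
  if (!allConstant(lbs, probe))
    return false;
  return allConstant(steps, probe);
}
```

Both forms return the same result; they differ only in how much work is done after the first failure.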
%alloc_0 = memref.alloc() : memref<8xf32>
%1 = scf.for %k = %lb to %ub step %step
iter_args(%iterBuf = %arg0) -> (memref<2xf32>) {
Just to confirm, what are we testing here by setting `iter_args` to `arg0`? Is it to check whether the scheduler can skip complex lifetimes?
Neither. The case simply demonstrates mixed usage of `scf.forall` and `scf.for`; its only testing purpose is the allocators in the case. By the way, `iter_args` still supports memref in addition to tensor.
We will not generate loops with memref in `iter_args` after bufferization. This feature is for dynamic allocations, which will be identified as complex access.
Thanks for your explanation. Complex access is not the main testing purpose of this case; it was originally borrowed from: https://github.com/llvm/llvm-project/blob/main/mlir/test/Dialect/Bufferization/Transforms/buffer-loop-hoisting.mlir#L163-L181.
while (parent) {
  if (auto forallOp = dyn_cast<scf::ForallOp>(parent)) {
    if (isForallLoopBoundStatic(forallOp)) {
      SmallVector<Value> upperBounds = forallOp.getUpperBound(builder);
Suggest using `getStaticUpperBound` instead.
It seems `getStaticUpperBound()` always returns `std::numeric_limits<int64_t>::min()`, so I kept the current implementation.
OK. That may be because we are using constant ops instead of attributes for the constant bounds. Another possible way would be to check that both the MixedValue and the Value are constants.
This doesn't make much difference; both ways should be fine?
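The behavior discussed in this thread can be pictured with a toy model. This is only a sketch under assumptions and does not use any MLIR API: `SsaValue`, `MixedBound`, `staticView`, and `constantView` are hypothetical names. It assumes a bound is either an attribute (a plain integer) or an SSA value that may come from a constant op, and that a "static" accessor sees only attributes and reports the sentinel mentioned above otherwise.

```cpp
#include <cstdint>
#include <limits>
#include <optional>
#include <variant>

// Sentinel reported in the thread for a bound that is not stored statically.
constexpr int64_t kDynamic = std::numeric_limits<int64_t>::min();

// Hypothetical SSA value: carries a constant only if defined by a constant op.
struct SsaValue {
  std::optional<int64_t> constantOpResult;
};

// Hypothetical mixed bound: an attribute (int64_t) or an SSA value.
using MixedBound = std::variant<int64_t, SsaValue>;

// Attribute-only view: SSA values are reported as the sentinel, even when a
// constant op defines them.
int64_t staticView(const MixedBound &bound) {
  if (const auto *attr = std::get_if<int64_t>(&bound))
    return *attr;
  return kDynamic;
}

// View that looks through both forms and yields a constant when one exists.
std::optional<int64_t> constantView(const MixedBound &bound) {
  if (const auto *attr = std::get_if<int64_t>(&bound))
    return *attr;
  return std::get<SsaValue>(bound).constantOpResult;
}
```

In this model, a bound produced by a constant op is invisible to `staticView` but visible to `constantView`, which matches the explanation given above for why the static accessor was not usable here.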
llvm::zip(forallOp.getMixedLowerBound(), forallOp.getMixedUpperBound(),
          forallOp.getMixedStep())) {
  std::optional<int64_t> ubConst = getConstantIntValue(ub);
  return ubConst.has_value() && isConstantIntValue(lb, 0) &&
Why not support the case with `lb != 0` or `step != 1`? We may have a loop like `scf.forall (%arg7) = (0) to (512) step (32)`.
This is not a hard limitation; it has now been updated to support this case.
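For reference, supporting a nonzero lower bound and a non-unit step amounts to a ceiling division. The sketch below is illustrative only; `tripCount` is a hypothetical helper, not the code in this PR, and it assumes a positive step and a half-open range `[lb, ub)`.

```cpp
#include <cassert>
#include <cstdint>

// Number of iterations of `for (i = lb; i < ub; i += step)`.
// Assumes step > 0; an empty or inverted range yields 0.
int64_t tripCount(int64_t lb, int64_t ub, int64_t step) {
  assert(step > 0 && "step must be positive");
  if (ub <= lb)
    return 0;
  // Ceiling division of the range length by the step.
  return (ub - lb + step - 1) / step;
}
```

For the loop quoted above, `tripCount(0, 512, 32)` gives 16 iterations.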

// Get the total number of threads from the outermost to the current level of
// the parallel loop that the allocation located in.
int64_t numThreads = 1;
The calculation of `numThreads` could directly use the upstream util `constantTripCount`.
Thanks for the tip; changed to use the util function.
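The code comment above describes `numThreads` as the total over all enclosing parallel-loop levels. A minimal sketch of that accumulation, assuming each level's trip count has already been computed (for example by a trip-count utility) and is absent when it is not a compile-time constant. `TripCount` and `totalThreads` are hypothetical names, and this is not the PR's implementation.

```cpp
#include <cstdint>
#include <optional>
#include <vector>

// Trip count of one parallel-loop level; empty when it is not constant.
using TripCount = std::optional<int64_t>;

// Multiplies the trip counts from the outermost level to the current one.
// Returns std::nullopt as soon as any level has a non-constant trip count.
std::optional<int64_t> totalThreads(const std::vector<TripCount> &levels) {
  int64_t numThreads = 1;
  for (const TripCount &tripCount : levels) {
    if (!tripCount)
      return std::nullopt;
    numThreads *= *tripCount;
  }
  return numThreads;
}
```

With two nested levels of 4 and 16 iterations this yields 64; with no enclosing parallel loop it yields 1.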
This PR tracks #120.
This is another implementation of hoisting thread-local allocators, based on the new memref-merge design, so it depends on #44.