-
Notifications
You must be signed in to change notification settings - Fork 11.9k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[flang][cuda] Enable data transfer for descriptors #92804
Conversation
@llvm/pr-subscribers-flang-fir-hlfir Author: Valentin Clement (バレンタイン クレメン) (clementval) Changes: Remove the TODO when data transfer is done with descriptor variables. Full diff: https://github.com/llvm/llvm-project/pull/92804.diff 4 Files Affected:
diff --git a/flang/include/flang/Optimizer/Dialect/CUF/CUFOps.td b/flang/include/flang/Optimizer/Dialect/CUF/CUFOps.td
index 72157bce4f768..b33aeca590b56 100644
--- a/flang/include/flang/Optimizer/Dialect/CUF/CUFOps.td
+++ b/flang/include/flang/Optimizer/Dialect/CUF/CUFOps.td
@@ -154,13 +154,15 @@ def cuf_DataTransferOp : cuf_Op<"data_transfer", []> {
```
}];
- let arguments = (ins Arg<AnyReferenceLike, "", [MemWrite]>:$src,
- Arg<AnyReferenceLike, "", [MemRead]>:$dst,
+ let arguments = (ins Arg<AnyRefOrBoxType, "", [MemWrite]>:$src,
+ Arg<AnyRefOrBoxType, "", [MemRead]>:$dst,
cuf_DataTransferKindAttr:$transfer_kind);
let assemblyFormat = [{
$src `to` $dst attr-dict `:` type(operands)
}];
+
+ let hasVerifier = 1;
}
def cuf_KernelLaunchOp : cuf_Op<"kernel_launch", [CallOpInterface,
diff --git a/flang/lib/Lower/Bridge.cpp b/flang/lib/Lower/Bridge.cpp
index 7ded9adcd5c2a..8e9ce78119d18 100644
--- a/flang/lib/Lower/Bridge.cpp
+++ b/flang/lib/Lower/Bridge.cpp
@@ -3717,8 +3717,17 @@ class FirConverter : public Fortran::lower::AbstractConverter {
hlfir::Entity &lhs, hlfir::Entity &rhs) {
bool lhsIsDevice = Fortran::evaluate::HasCUDAAttrs(assign.lhs);
bool rhsIsDevice = Fortran::evaluate::HasCUDAAttrs(assign.rhs);
- if (rhs.isBoxAddressOrValue() || lhs.isBoxAddressOrValue())
- TODO(loc, "CUDA data transfler with descriptors");
+
+ auto getRefIfLoaded = [](mlir::Value val) -> mlir::Value {
+ if (mlir::isa_and_nonnull<fir::LoadOp>(val.getDefiningOp())) {
+ auto loadOp = mlir::dyn_cast<fir::LoadOp>(val.getDefiningOp());
+ return loadOp.getMemref();
+ }
+ return val;
+ };
+
+ mlir::Value rhsVal = getRefIfLoaded(rhs.getBase());
+ mlir::Value lhsVal = getRefIfLoaded(lhs.getBase());
// device = host
if (lhsIsDevice && !rhsIsDevice) {
@@ -3727,11 +3736,12 @@ class FirConverter : public Fortran::lower::AbstractConverter {
if (!rhs.isVariable()) {
auto associate = hlfir::genAssociateExpr(
loc, builder, rhs, rhs.getType(), ".cuf_host_tmp");
- builder.create<cuf::DataTransferOp>(loc, associate.getBase(), lhs,
+ builder.create<cuf::DataTransferOp>(loc, associate.getBase(), lhsVal,
transferKindAttr);
builder.create<hlfir::EndAssociateOp>(loc, associate);
} else {
- builder.create<cuf::DataTransferOp>(loc, rhs, lhs, transferKindAttr);
+ builder.create<cuf::DataTransferOp>(loc, rhsVal, lhsVal,
+ transferKindAttr);
}
return;
}
@@ -3740,26 +3750,18 @@ class FirConverter : public Fortran::lower::AbstractConverter {
if (!lhsIsDevice && rhsIsDevice) {
auto transferKindAttr = cuf::DataTransferKindAttr::get(
builder.getContext(), cuf::DataTransferKind::DeviceHost);
- if (!rhs.isVariable()) {
- // evaluateRhs loads scalar. Look for the memory reference to be used in
- // the transfer.
- if (mlir::isa_and_nonnull<fir::LoadOp>(rhs.getDefiningOp())) {
- auto loadOp = mlir::dyn_cast<fir::LoadOp>(rhs.getDefiningOp());
- builder.create<cuf::DataTransferOp>(loc, loadOp.getMemref(), lhs,
- transferKindAttr);
- return;
- }
- } else {
- builder.create<cuf::DataTransferOp>(loc, rhs, lhs, transferKindAttr);
- }
+ builder.create<cuf::DataTransferOp>(loc, rhsVal, lhsVal,
+ transferKindAttr);
return;
}
+ // device = device
if (lhsIsDevice && rhsIsDevice) {
assert(rhs.isVariable() && "CUDA Fortran assignment rhs is not legal");
auto transferKindAttr = cuf::DataTransferKindAttr::get(
builder.getContext(), cuf::DataTransferKind::DeviceDevice);
- builder.create<cuf::DataTransferOp>(loc, rhs, lhs, transferKindAttr);
+ builder.create<cuf::DataTransferOp>(loc, rhsVal, lhsVal,
+ transferKindAttr);
return;
}
llvm_unreachable("Unhandled CUDA data transfer");
diff --git a/flang/lib/Optimizer/Dialect/CUF/CUFOps.cpp b/flang/lib/Optimizer/Dialect/CUF/CUFOps.cpp
index 870652c72fab7..b00c374682922 100644
--- a/flang/lib/Optimizer/Dialect/CUF/CUFOps.cpp
+++ b/flang/lib/Optimizer/Dialect/CUF/CUFOps.cpp
@@ -89,6 +89,19 @@ mlir::LogicalResult cuf::AllocateOp::verify() {
return mlir::success();
}
+//===----------------------------------------------------------------------===//
+// DataTransferOp
+//===----------------------------------------------------------------------===//
+
+mlir::LogicalResult cuf::DataTransferOp::verify() {
+ mlir::Type srcTy = getSrc().getType();
+ mlir::Type dstTy = getDst().getType();
+ if (fir::isa_ref_type(srcTy) && fir::isa_ref_type(dstTy) ||
+ fir::isa_box_type(srcTy) && fir::isa_box_type(dstTy))
+ return mlir::success();
+ return emitOpError("expect src and dst to be both references or descriptors");
+}
+
//===----------------------------------------------------------------------===//
// DeallocateOp
//===----------------------------------------------------------------------===//
diff --git a/flang/test/Lower/CUDA/cuda-data-transfer.cuf b/flang/test/Lower/CUDA/cuda-data-transfer.cuf
index 084314ed63ecd..e23792e6efc55 100644
--- a/flang/test/Lower/CUDA/cuda-data-transfer.cuf
+++ b/flang/test/Lower/CUDA/cuda-data-transfer.cuf
@@ -159,3 +159,22 @@ end subroutine
! CHECK-LABEL: func.func @_QPsub6
! CHECK: cuf.data_transfer
+
+subroutine sub7(a, b, c)
+ integer, device, allocatable :: a(:), c(:)
+ integer, allocatable :: b(:)
+ b = a
+
+ a = b
+
+ c = a
+end subroutine
+
+! CHECK-LABEL: func.func @_QPsub7(
+! CHECK-SAME: %[[ARG0:.*]]: !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>> {cuf.data_attr = #cuf.cuda<device>, fir.bindc_name = "a"}, %[[ARG1:.*]]: !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>> {fir.bindc_name = "b"}, %[[ARG2:.*]]: !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>> {cuf.data_attr = #cuf.cuda<device>, fir.bindc_name = "c"}) {
+! CHECK: %[[A:.*]]:2 = hlfir.declare %[[ARG0]] dummy_scope %{{.*}} {data_attr = #cuf.cuda<device>, fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QFsub7Ea"} : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>, !fir.dscope) -> (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>, !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>)
+! CHECK: %[[B:.*]]:2 = hlfir.declare %[[ARG1]] dummy_scope %{{.*}} {fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QFsub7Eb"} : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>, !fir.dscope) -> (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>, !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>)
+! CHECK: %[[C:.*]]:2 = hlfir.declare %[[ARG2]] dummy_scope %0 {data_attr = #cuf.cuda<device>, fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QFsub7Ec"} : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>, !fir.dscope) -> (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>, !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>)
+! CHECK: cuf.data_transfer %[[A]]#0 to %[[B]]#0 {transfer_kind = #cuf.cuda_transfer<device_host>} : !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>, !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>
+! CHECK: cuf.data_transfer %[[B]]#0 to %[[A]]#0 {transfer_kind = #cuf.cuda_transfer<host_device>} : !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>, !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>
+! CHECK: cuf.data_transfer %[[A]]#0 to %[[C]]#0 {transfer_kind = #cuf.cuda_transfer<device_device>} : !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>, !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>
|
@vzakhari Does it look OK to you with the changes? |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you, Valentin!
Remove the TODO when data transfer is done with descriptor variables.