Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[flang][cuda] Enable data transfer for descriptors #92804

Merged
merged 4 commits into from
May 21, 2024

Conversation

clementval
Copy link
Contributor

Remove the TODO that was hit when a CUDA data transfer involves descriptor variables, and lower such transfers to `cuf.data_transfer`.

@llvmbot llvmbot added flang Flang issues not falling into any other category flang:fir-hlfir labels May 20, 2024
@clementval clementval changed the title [flang][cuda] Enable data transfer for descriptor [flang][cuda] Enable data transfer for descriptors May 20, 2024
@llvmbot
Copy link
Collaborator

llvmbot commented May 20, 2024

@llvm/pr-subscribers-flang-fir-hlfir

Author: Valentin Clement (バレンタイン クレメン) (clementval)

Changes

Remove the TODO that was hit when a CUDA data transfer involves descriptor variables, and lower such transfers to `cuf.data_transfer`.


Full diff: https://github.com/llvm/llvm-project/pull/92804.diff

4 Files Affected:

  • (modified) flang/include/flang/Optimizer/Dialect/CUF/CUFOps.td (+4-2)
  • (modified) flang/lib/Lower/Bridge.cpp (+19-17)
  • (modified) flang/lib/Optimizer/Dialect/CUF/CUFOps.cpp (+13)
  • (modified) flang/test/Lower/CUDA/cuda-data-transfer.cuf (+19)
diff --git a/flang/include/flang/Optimizer/Dialect/CUF/CUFOps.td b/flang/include/flang/Optimizer/Dialect/CUF/CUFOps.td
index 72157bce4f768..b33aeca590b56 100644
--- a/flang/include/flang/Optimizer/Dialect/CUF/CUFOps.td
+++ b/flang/include/flang/Optimizer/Dialect/CUF/CUFOps.td
@@ -154,13 +154,15 @@ def cuf_DataTransferOp : cuf_Op<"data_transfer", []> {
     ```
   }];
 
-  let arguments = (ins Arg<AnyReferenceLike, "", [MemWrite]>:$src,
-                       Arg<AnyReferenceLike, "", [MemRead]>:$dst,
+  let arguments = (ins Arg<AnyRefOrBoxType, "", [MemWrite]>:$src,
+                       Arg<AnyRefOrBoxType, "", [MemRead]>:$dst,
                        cuf_DataTransferKindAttr:$transfer_kind);
 
   let assemblyFormat = [{
     $src `to` $dst attr-dict `:` type(operands)
   }];
+
+  let hasVerifier = 1;
 }
 
 def cuf_KernelLaunchOp : cuf_Op<"kernel_launch", [CallOpInterface,
diff --git a/flang/lib/Lower/Bridge.cpp b/flang/lib/Lower/Bridge.cpp
index 7ded9adcd5c2a..8e9ce78119d18 100644
--- a/flang/lib/Lower/Bridge.cpp
+++ b/flang/lib/Lower/Bridge.cpp
@@ -3717,8 +3717,17 @@ class FirConverter : public Fortran::lower::AbstractConverter {
                            hlfir::Entity &lhs, hlfir::Entity &rhs) {
     bool lhsIsDevice = Fortran::evaluate::HasCUDAAttrs(assign.lhs);
     bool rhsIsDevice = Fortran::evaluate::HasCUDAAttrs(assign.rhs);
-    if (rhs.isBoxAddressOrValue() || lhs.isBoxAddressOrValue())
-      TODO(loc, "CUDA data transfler with descriptors");
+
+    auto getRefIfLoaded = [](mlir::Value val) -> mlir::Value {
+      if (mlir::isa_and_nonnull<fir::LoadOp>(val.getDefiningOp())) {
+        auto loadOp = mlir::dyn_cast<fir::LoadOp>(val.getDefiningOp());
+        return loadOp.getMemref();
+      }
+      return val;
+    };
+
+    mlir::Value rhsVal = getRefIfLoaded(rhs.getBase());
+    mlir::Value lhsVal = getRefIfLoaded(lhs.getBase());
 
     // device = host
     if (lhsIsDevice && !rhsIsDevice) {
@@ -3727,11 +3736,12 @@ class FirConverter : public Fortran::lower::AbstractConverter {
       if (!rhs.isVariable()) {
         auto associate = hlfir::genAssociateExpr(
             loc, builder, rhs, rhs.getType(), ".cuf_host_tmp");
-        builder.create<cuf::DataTransferOp>(loc, associate.getBase(), lhs,
+        builder.create<cuf::DataTransferOp>(loc, associate.getBase(), lhsVal,
                                             transferKindAttr);
         builder.create<hlfir::EndAssociateOp>(loc, associate);
       } else {
-        builder.create<cuf::DataTransferOp>(loc, rhs, lhs, transferKindAttr);
+        builder.create<cuf::DataTransferOp>(loc, rhsVal, lhsVal,
+                                            transferKindAttr);
       }
       return;
     }
@@ -3740,26 +3750,18 @@ class FirConverter : public Fortran::lower::AbstractConverter {
     if (!lhsIsDevice && rhsIsDevice) {
       auto transferKindAttr = cuf::DataTransferKindAttr::get(
           builder.getContext(), cuf::DataTransferKind::DeviceHost);
-      if (!rhs.isVariable()) {
-        // evaluateRhs loads scalar. Look for the memory reference to be used in
-        // the transfer.
-        if (mlir::isa_and_nonnull<fir::LoadOp>(rhs.getDefiningOp())) {
-          auto loadOp = mlir::dyn_cast<fir::LoadOp>(rhs.getDefiningOp());
-          builder.create<cuf::DataTransferOp>(loc, loadOp.getMemref(), lhs,
-                                              transferKindAttr);
-          return;
-        }
-      } else {
-        builder.create<cuf::DataTransferOp>(loc, rhs, lhs, transferKindAttr);
-      }
+      builder.create<cuf::DataTransferOp>(loc, rhsVal, lhsVal,
+                                          transferKindAttr);
       return;
     }
 
+    // device = device
     if (lhsIsDevice && rhsIsDevice) {
       assert(rhs.isVariable() && "CUDA Fortran assignment rhs is not legal");
       auto transferKindAttr = cuf::DataTransferKindAttr::get(
           builder.getContext(), cuf::DataTransferKind::DeviceDevice);
-      builder.create<cuf::DataTransferOp>(loc, rhs, lhs, transferKindAttr);
+      builder.create<cuf::DataTransferOp>(loc, rhsVal, lhsVal,
+                                          transferKindAttr);
       return;
     }
     llvm_unreachable("Unhandled CUDA data transfer");
diff --git a/flang/lib/Optimizer/Dialect/CUF/CUFOps.cpp b/flang/lib/Optimizer/Dialect/CUF/CUFOps.cpp
index 870652c72fab7..b00c374682922 100644
--- a/flang/lib/Optimizer/Dialect/CUF/CUFOps.cpp
+++ b/flang/lib/Optimizer/Dialect/CUF/CUFOps.cpp
@@ -89,6 +89,19 @@ mlir::LogicalResult cuf::AllocateOp::verify() {
   return mlir::success();
 }
 
+//===----------------------------------------------------------------------===//
+// DataTransferOp
+//===----------------------------------------------------------------------===//
+
+mlir::LogicalResult cuf::DataTransferOp::verify() {
+  mlir::Type srcTy = getSrc().getType();
+  mlir::Type dstTy = getDst().getType();
+  if (fir::isa_ref_type(srcTy) && fir::isa_ref_type(dstTy) ||
+      fir::isa_box_type(srcTy) && fir::isa_box_type(dstTy))
+    return mlir::success();
+  return emitOpError("expect src and dst to be both references or descriptors");
+}
+
 //===----------------------------------------------------------------------===//
 // DeallocateOp
 //===----------------------------------------------------------------------===//
diff --git a/flang/test/Lower/CUDA/cuda-data-transfer.cuf b/flang/test/Lower/CUDA/cuda-data-transfer.cuf
index 084314ed63ecd..e23792e6efc55 100644
--- a/flang/test/Lower/CUDA/cuda-data-transfer.cuf
+++ b/flang/test/Lower/CUDA/cuda-data-transfer.cuf
@@ -159,3 +159,22 @@ end subroutine
 
 ! CHECK-LABEL: func.func @_QPsub6
 ! CHECK: cuf.data_transfer
+
+subroutine sub7(a, b, c)
+  integer, device, allocatable :: a(:), c(:)
+  integer, allocatable :: b(:)
+  b = a
+
+  a = b
+
+  c = a
+end subroutine
+
+! CHECK-LABEL: func.func @_QPsub7(
+! CHECK-SAME:  %[[ARG0:.*]]: !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>> {cuf.data_attr = #cuf.cuda<device>, fir.bindc_name = "a"}, %[[ARG1:.*]]: !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>> {fir.bindc_name = "b"}, %[[ARG2:.*]]: !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>> {cuf.data_attr = #cuf.cuda<device>, fir.bindc_name = "c"}) {
+! CHECK: %[[A:.*]]:2 = hlfir.declare %[[ARG0]] dummy_scope %{{.*}} {data_attr = #cuf.cuda<device>, fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QFsub7Ea"} : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>, !fir.dscope) -> (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>, !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>)
+! CHECK: %[[B:.*]]:2 = hlfir.declare %[[ARG1]] dummy_scope %{{.*}} {fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QFsub7Eb"} : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>, !fir.dscope) -> (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>, !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>)
+! CHECK: %[[C:.*]]:2 = hlfir.declare %[[ARG2]] dummy_scope %0 {data_attr = #cuf.cuda<device>, fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QFsub7Ec"} : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>, !fir.dscope) -> (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>, !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>)
+! CHECK: cuf.data_transfer %[[A]]#0 to %[[B]]#0 {transfer_kind = #cuf.cuda_transfer<device_host>} : !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>, !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>
+! CHECK: cuf.data_transfer %[[B]]#0 to %[[A]]#0 {transfer_kind = #cuf.cuda_transfer<host_device>} : !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>, !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>
+! CHECK: cuf.data_transfer %[[A]]#0 to %[[C]]#0 {transfer_kind = #cuf.cuda_transfer<device_device>} : !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>, !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>

flang/lib/Lower/Bridge.cpp Outdated Show resolved Hide resolved
flang/include/flang/Optimizer/Dialect/CUF/CUFOps.td Outdated Show resolved Hide resolved
@clementval
Copy link
Contributor Author

@vzakhari Does it look OK to you with the changes?

Copy link
Contributor

@vzakhari vzakhari left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you, Valentin!

@clementval clementval merged commit 1fc3ce1 into llvm:main May 21, 2024
3 of 4 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
flang:fir-hlfir flang Flang issues not falling into any other category
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants