diff --git a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td index 2685cdc4efa3..37867e6e1998 100644 --- a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td +++ b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td @@ -139,6 +139,12 @@ def OpenACC_ParallelOp : OpenACC_Op<"parallel", static StringRef getAttachKeyword() { return "attach"; } static StringRef getPrivateKeyword() { return "private"; } static StringRef getFirstPrivateKeyword() { return "firstprivate"; } + + /// The number of data operands. + unsigned getNumDataOperands(); + + /// The i-th data operand passed. + Value getDataOperand(unsigned i); }]; let verifier = ?; diff --git a/mlir/lib/Conversion/OpenACCToLLVM/OpenACCToLLVM.cpp b/mlir/lib/Conversion/OpenACCToLLVM/OpenACCToLLVM.cpp index eb34b0e7c9dc..fce5ed0d1295 100644 --- a/mlir/lib/Conversion/OpenACCToLLVM/OpenACCToLLVM.cpp +++ b/mlir/lib/Conversion/OpenACCToLLVM/OpenACCToLLVM.cpp @@ -140,6 +140,7 @@ void mlir::populateOpenACCToLLVMConversionPatterns( patterns.add>(converter); patterns.add>(converter); patterns.add>(converter); + patterns.add>(converter); patterns.add>(converter); } @@ -201,6 +202,24 @@ void ConvertOpenACCToLLVMPass::runOnOperation() { allDataOperandsAreConverted(op.detachOperands()); }); + target.addDynamicallyLegalOp( + [allDataOperandsAreConverted](acc::ParallelOp op) { + return allDataOperandsAreConverted(op.reductionOperands()) && + allDataOperandsAreConverted(op.copyOperands()) && + allDataOperandsAreConverted(op.copyinOperands()) && + allDataOperandsAreConverted(op.copyinReadonlyOperands()) && + allDataOperandsAreConverted(op.copyoutOperands()) && + allDataOperandsAreConverted(op.copyoutZeroOperands()) && + allDataOperandsAreConverted(op.createOperands()) && + allDataOperandsAreConverted(op.createZeroOperands()) && + allDataOperandsAreConverted(op.noCreateOperands()) && + allDataOperandsAreConverted(op.presentOperands()) && + allDataOperandsAreConverted(op.devicePtrOperands()) && + allDataOperandsAreConverted(op.attachOperands()) && + allDataOperandsAreConverted(op.gangPrivateOperands()) && + allDataOperandsAreConverted(op.gangFirstPrivateOperands()); + }); + target.addDynamicallyLegalOp( [allDataOperandsAreConverted](acc::UpdateOp op) { return allDataOperandsAreConverted(op.hostOperands()) && diff --git a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp index 974e67ff83cc..b1cc4e796120 100644 --- a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp +++ b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp @@ -448,6 +448,26 @@ static void print(OpAsmPrinter &printer, ParallelOp &op) { op->getAttrs(), ParallelOp::getOperandSegmentSizeAttr()); } +unsigned ParallelOp::getNumDataOperands() { + return reductionOperands().size() + copyOperands().size() + + copyinOperands().size() + copyinReadonlyOperands().size() + + copyoutOperands().size() + copyoutZeroOperands().size() + + createOperands().size() + createZeroOperands().size() + + noCreateOperands().size() + presentOperands().size() + + devicePtrOperands().size() + attachOperands().size() + + gangPrivateOperands().size() + gangFirstPrivateOperands().size(); +} + +Value ParallelOp::getDataOperand(unsigned i) { + unsigned numOptional = async() ? 1 : 0; + numOptional += numGangs() ? 1 : 0; + numOptional += numWorkers() ? 1 : 0; + numOptional += vectorLength() ? 1 : 0; + numOptional += ifCond() ? 1 : 0; + numOptional += selfCond() ? 1 : 0; + return getOperand(waitOperands().size() + numOptional + i); +} + //===----------------------------------------------------------------------===// // LoopOp //===----------------------------------------------------------------------===// diff --git a/mlir/test/Conversion/OpenACCToLLVM/convert-data-operands-to-llvmir.mlir b/mlir/test/Conversion/OpenACCToLLVM/convert-data-operands-to-llvmir.mlir index b25f80b165be..0c4d9a9210a8 100644 --- a/mlir/test/Conversion/OpenACCToLLVM/convert-data-operands-to-llvmir.mlir +++ b/mlir/test/Conversion/OpenACCToLLVM/convert-data-operands-to-llvmir.mlir @@ -159,3 +159,65 @@ func @testdataregion(%a: memref<10xf32>, %b: memref<10xf32>) -> () { } // CHECK: acc.data present(%{{.*}}, %{{.*}} : !llvm.struct<"openacc_data", (struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>, ptr, i64)>, !llvm.struct<"openacc_data.1", (struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>, ptr, i64)>) + +// ----- + +func @testparallelop(%a: memref<10xf32>, %b: memref<10xf32>) -> () { + acc.parallel copy(%b : memref<10xf32>) copyout(%a : memref<10xf32>) { + } + return +} + +// CHECK: acc.parallel copy(%{{.*}}: !llvm.struct<"openacc_data", (struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>, ptr, i64)>) copyout(%{{.*}}: !llvm.struct<"openacc_data.1", (struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>, ptr, i64)>) + +// ----- + +func @testparallelop(%a: !llvm.ptr, %b: memref<10xf32>, %c: !llvm.ptr) -> () { + acc.parallel copyin(%b : memref<10xf32>) deviceptr(%c: !llvm.ptr) attach(%a : !llvm.ptr) { + } + return +} + +// CHECK: acc.parallel copyin(%{{.*}}: !llvm.struct<"openacc_data", (struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>, ptr, i64)>) deviceptr(%{{.*}}: !llvm.ptr) attach(%{{.*}}: !llvm.ptr) + +// ----- + +func @testparallelop(%a: memref<10xf32>, %b: memref<10xf32>) -> () { + %ifCond = constant true + acc.parallel if(%ifCond) copyin_readonly(%b : memref<10xf32>) copyout_zero(%a : memref<10xf32>) { + } + return +} + +// CHECK: acc.parallel if(%{{.*}}) copyin_readonly(%{{.*}}: !llvm.struct<"openacc_data", (struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>, ptr, i64)>) copyout_zero(%{{.*}}: !llvm.struct<"openacc_data.1", (struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>, ptr, i64)>) + +// ----- + +func @testparallelop(%a: !llvm.ptr, %b: memref<10xf32>, %c: !llvm.ptr) -> () { + acc.parallel create(%b : memref<10xf32>) create_zero(%c: !llvm.ptr) no_create(%a : !llvm.ptr) { + } + return +} + +// CHECK: acc.parallel create(%{{.*}}: !llvm.struct<"openacc_data", (struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>, ptr, i64)>) create_zero(%{{.*}}: !llvm.ptr) no_create(%{{.*}}: !llvm.ptr) + +// ----- + +func @testparallelop(%a: memref<10xf32>, %b: memref<10xf32>) -> () { + acc.parallel present(%a: memref<10xf32>, %b: memref<10xf32>) { + } + return +} + +// CHECK: acc.parallel present(%{{.*}}: !llvm.struct<"openacc_data", (struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>, ptr, i64)>, %{{.*}}: !llvm.struct<"openacc_data.1", (struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>, ptr, i64)>) + +// ----- + +func @testparallelop(%i: i64, %a: memref<10xf32>, %b: memref<10xf32>) -> () { + acc.parallel num_gangs(%i: i64) present(%a: memref<10xf32>, %b: memref<10xf32>) { + } attributes {async} + return +} + +// CHECK: acc.parallel num_gangs(%{{.*}}: i64) present(%{{.*}}: !llvm.struct<"openacc_data", (struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>, ptr, i64)>, %{{.*}}: !llvm.struct<"openacc_data.1", (struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>, ptr, i64)>) +// CHECK-NEXT: } attributes {async}