[mlir][openacc] Add conversion for if operand to scf.if for standalone data operation

This patch convert the if condition on standalone data operation such as acc.update,
acc.enter_data and acc.exit_data to a scf.if with the operation in the if region.
It removes the operation when the if condition is constant and false. It removes the
the condition if it is contant and true.

Conversion to scf.if is done in order to use the translation to LLVM IR dialect out of the box.
Not sure this is the best approach or we should perform this during the translation from OpenACC
to LLVM IR dialect. Any thoughts welcome.

Reviewed By: ftynse

Differential Revision: https://reviews.llvm.org/D103325
This commit is contained in:
Valentin Clement 2021-06-07 12:09:25 -04:00 committed by clementval
parent aa4e6a609a
commit fb5b590b5e
9 changed files with 186 additions and 2 deletions

View File

@ -0,0 +1,28 @@
//===- ConvertOpenACCToSCF.h - OpenACC conversion pass entrypoint ---------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_CONVERSION_OPENACCTOSCF_CONVERTOPENACCTOSCF_H
#define MLIR_CONVERSION_OPENACCTOSCF_CONVERTOPENACCTOSCF_H
#include <memory>
namespace mlir {
class ModuleOp;
template <typename T>
class OperationPass;
class RewritePatternSet;
/// Collect the patterns to convert from the OpenACC dialect to OpenACC with
/// SCF dialect.
void populateOpenACCToSCFConversionPatterns(RewritePatternSet &patterns);
/// Create a pass to convert the OpenACC dialect into the LLVMIR dialect.
std::unique_ptr<OperationPass<ModuleOp>> createConvertOpenACCToSCFPass();
} // namespace mlir
#endif // MLIR_CONVERSION_OPENACCTOSCF_CONVERTOPENACCTOSCF_H

View File

@ -23,6 +23,7 @@
#include "mlir/Conversion/LinalgToStandard/LinalgToStandard.h"
#include "mlir/Conversion/MathToLibm/MathToLibm.h"
#include "mlir/Conversion/OpenACCToLLVM/ConvertOpenACCToLLVM.h"
#include "mlir/Conversion/OpenACCToSCF/ConvertOpenACCToSCF.h"
#include "mlir/Conversion/OpenMPToLLVM/ConvertOpenMPToLLVM.h"
#include "mlir/Conversion/PDLToPDLInterp/PDLToPDLInterp.h"
#include "mlir/Conversion/SCFToGPU/SCFToGPUPass.h"

View File

@ -255,6 +255,16 @@ def ConvertMathToLibm : Pass<"convert-math-to-libm", "ModuleOp"> {
let dependentDialects = ["StandardOpsDialect", "vector::VectorDialect"];
}
//===----------------------------------------------------------------------===//
// OpenACCToSCF
//===----------------------------------------------------------------------===//
def ConvertOpenACCToSCF : Pass<"convert-openacc-to-scf", "ModuleOp"> {
let summary = "Convert the OpenACC ops to OpenACC with SCF dialect";
let constructor = "mlir::createConvertOpenACCToSCFPass()";
let dependentDialects = ["scf::SCFDialect", "acc::OpenACCDialect"];
}
//===----------------------------------------------------------------------===//
// OpenACCToLLVM
//===----------------------------------------------------------------------===//

View File

@ -503,13 +503,13 @@ def OpenACC_UpdateOp : OpenACC_Op<"update", [AttrSizedOperandSegments]> {
```
}];
let arguments = (ins Optional<IntOrIndex>:$asyncOperand,
let arguments = (ins Optional<I1>:$ifCond,
Optional<IntOrIndex>:$asyncOperand,
Optional<IntOrIndex>:$waitDevnum,
Variadic<IntOrIndex>:$waitOperands,
UnitAttr:$async,
UnitAttr:$wait,
Variadic<IntOrIndex>:$deviceTypeOperands,
Optional<I1>:$ifCond,
Variadic<AnyType>:$hostOperands,
Variadic<AnyType>:$deviceOperands,
UnitAttr:$ifPresent);

View File

@ -12,6 +12,7 @@ add_subdirectory(LinalgToSPIRV)
add_subdirectory(LinalgToStandard)
add_subdirectory(MathToLibm)
add_subdirectory(OpenACCToLLVM)
add_subdirectory(OpenACCToSCF)
add_subdirectory(OpenMPToLLVM)
add_subdirectory(PDLToPDLInterp)
add_subdirectory(SCFToGPU)

View File

@ -0,0 +1,15 @@
add_mlir_conversion_library(MLIROpenACCToSCF
OpenACCToSCF.cpp
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/OpenACCToSCF
DEPENDS
MLIRConversionPassIncGen
LINK_LIBS PUBLIC
MLIRIR
MLIROpenACC
MLIRTransforms
MLIRSCF
)

View File

@ -0,0 +1,90 @@
//===- OpenACCToSCF.cpp - OpenACC condition to SCF if conversion ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "../PassDetail.h"
#include "mlir/Conversion/OpenACCToSCF/ConvertOpenACCToSCF.h"
#include "mlir/Dialect/OpenACC/OpenACC.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Transforms/DialectConversion.h"
using namespace mlir;
//===----------------------------------------------------------------------===//
// Conversion patterns
//===----------------------------------------------------------------------===//
namespace {
/// Pattern to transform the `ifCond` on operation without region into a scf.if
/// and move the operation into the `then` region.
template <typename OpTy>
class ExpandIfCondition : public OpRewritePattern<OpTy> {
using OpRewritePattern<OpTy>::OpRewritePattern;
LogicalResult matchAndRewrite(OpTy op,
PatternRewriter &rewriter) const override {
// Early exit if there is no condition.
if (!op.ifCond())
return success();
// Condition is not a constant.
if (!op.ifCond().template getDefiningOp<ConstantOp>()) {
auto ifOp = rewriter.create<scf::IfOp>(op.getLoc(), TypeRange(),
op.ifCond(), false);
rewriter.updateRootInPlace(op, [&]() { op.ifCondMutable().erase(0); });
auto thenBodyBuilder = ifOp.getThenBodyBuilder();
thenBodyBuilder.setListener(rewriter.getListener());
thenBodyBuilder.clone(*op.getOperation());
rewriter.eraseOp(op);
}
return success();
}
};
} // namespace
void mlir::populateOpenACCToSCFConversionPatterns(RewritePatternSet &patterns) {
patterns.add<ExpandIfCondition<acc::EnterDataOp>>(patterns.getContext());
patterns.add<ExpandIfCondition<acc::ExitDataOp>>(patterns.getContext());
patterns.add<ExpandIfCondition<acc::UpdateOp>>(patterns.getContext());
}
namespace {
struct ConvertOpenACCToSCFPass
: public ConvertOpenACCToSCFBase<ConvertOpenACCToSCFPass> {
void runOnOperation() override;
};
} // namespace
void ConvertOpenACCToSCFPass::runOnOperation() {
auto op = getOperation();
auto *context = op.getContext();
RewritePatternSet patterns(context);
ConversionTarget target(*context);
populateOpenACCToSCFConversionPatterns(patterns);
target.addLegalDialect<scf::SCFDialect>();
target.addLegalDialect<acc::OpenACCDialect>();
target.addDynamicallyLegalOp<acc::EnterDataOp>(
[](acc::EnterDataOp op) { return !op.ifCond(); });
target.addDynamicallyLegalOp<acc::ExitDataOp>(
[](acc::ExitDataOp op) { return !op.ifCond(); });
target.addDynamicallyLegalOp<acc::UpdateOp>(
[](acc::UpdateOp op) { return !op.ifCond(); });
if (failed(applyPartialConversion(op, target, std::move(patterns))))
signalPassFailure();
}
std::unique_ptr<OperationPass<ModuleOp>> mlir::createConvertOpenACCToSCFPass() {
return std::make_unique<ConvertOpenACCToSCFPass>();
}

View File

@ -19,6 +19,10 @@ class StandardOpsDialect;
template <typename ConcreteDialect>
void registerDialect(DialectRegistry &registry);
namespace acc {
class OpenACCDialect;
} // end namespace acc
namespace complex {
class ComplexDialect;
} // end namespace complex

View File

@ -0,0 +1,35 @@
// RUN: mlir-opt %s -convert-openacc-to-scf -split-input-file | FileCheck %s
func @testenterdataop(%a: memref<10xf32>, %ifCond: i1) -> () {
acc.enter_data if(%ifCond) create(%a: memref<10xf32>)
return
}
// CHECK: func @testenterdataop(%{{.*}}: memref<10xf32>, [[IFCOND:%.*]]: i1)
// CHECK: scf.if [[IFCOND]] {
// CHECK-NEXT: acc.enter_data create(%{{.*}} : memref<10xf32>)
// CHECK-NEXT: }
// -----
func @testexitdataop(%a: memref<10xf32>, %ifCond: i1) -> () {
acc.exit_data if(%ifCond) delete(%a: memref<10xf32>)
return
}
// CHECK: func @testexitdataop(%{{.*}}: memref<10xf32>, [[IFCOND:%.*]]: i1)
// CHECK: scf.if [[IFCOND]] {
// CHECK-NEXT: acc.exit_data delete(%{{.*}} : memref<10xf32>)
// CHECK-NEXT: }
// -----
func @testupdateop(%a: memref<10xf32>, %ifCond: i1) -> () {
acc.update if(%ifCond) host(%a: memref<10xf32>)
return
}
// CHECK: func @testupdateop(%{{.*}}: memref<10xf32>, [[IFCOND:%.*]]: i1)
// CHECK: scf.if [[IFCOND]] {
// CHECK-NEXT: acc.update host(%{{.*}} : memref<10xf32>)
// CHECK-NEXT: }