forked from OSchip/llvm-project
[mlir][openacc] Add conversion for if operand to scf.if for standalone data operation
This patch convert the if condition on standalone data operation such as acc.update, acc.enter_data and acc.exit_data to a scf.if with the operation in the if region. It removes the operation when the if condition is constant and false. It removes the the condition if it is contant and true. Conversion to scf.if is done in order to use the translation to LLVM IR dialect out of the box. Not sure this is the best approach or we should perform this during the translation from OpenACC to LLVM IR dialect. Any thoughts welcome. Reviewed By: ftynse Differential Revision: https://reviews.llvm.org/D103325
This commit is contained in:
parent
aa4e6a609a
commit
fb5b590b5e
|
@ -0,0 +1,28 @@
|
|||
//===- ConvertOpenACCToSCF.h - OpenACC conversion pass entrypoint ---------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
#ifndef MLIR_CONVERSION_OPENACCTOSCF_CONVERTOPENACCTOSCF_H
|
||||
#define MLIR_CONVERSION_OPENACCTOSCF_CONVERTOPENACCTOSCF_H
|
||||
|
||||
#include <memory>
|
||||
|
||||
namespace mlir {
|
||||
class ModuleOp;
|
||||
template <typename T>
|
||||
class OperationPass;
|
||||
class RewritePatternSet;
|
||||
|
||||
/// Collect the patterns to convert from the OpenACC dialect to OpenACC with
|
||||
/// SCF dialect.
|
||||
void populateOpenACCToSCFConversionPatterns(RewritePatternSet &patterns);
|
||||
|
||||
/// Create a pass to convert the OpenACC dialect into the LLVMIR dialect.
|
||||
std::unique_ptr<OperationPass<ModuleOp>> createConvertOpenACCToSCFPass();
|
||||
|
||||
} // namespace mlir
|
||||
|
||||
#endif // MLIR_CONVERSION_OPENACCTOSCF_CONVERTOPENACCTOSCF_H
|
|
@ -23,6 +23,7 @@
|
|||
#include "mlir/Conversion/LinalgToStandard/LinalgToStandard.h"
|
||||
#include "mlir/Conversion/MathToLibm/MathToLibm.h"
|
||||
#include "mlir/Conversion/OpenACCToLLVM/ConvertOpenACCToLLVM.h"
|
||||
#include "mlir/Conversion/OpenACCToSCF/ConvertOpenACCToSCF.h"
|
||||
#include "mlir/Conversion/OpenMPToLLVM/ConvertOpenMPToLLVM.h"
|
||||
#include "mlir/Conversion/PDLToPDLInterp/PDLToPDLInterp.h"
|
||||
#include "mlir/Conversion/SCFToGPU/SCFToGPUPass.h"
|
||||
|
|
|
@ -255,6 +255,16 @@ def ConvertMathToLibm : Pass<"convert-math-to-libm", "ModuleOp"> {
|
|||
let dependentDialects = ["StandardOpsDialect", "vector::VectorDialect"];
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// OpenACCToSCF
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def ConvertOpenACCToSCF : Pass<"convert-openacc-to-scf", "ModuleOp"> {
|
||||
let summary = "Convert the OpenACC ops to OpenACC with SCF dialect";
|
||||
let constructor = "mlir::createConvertOpenACCToSCFPass()";
|
||||
let dependentDialects = ["scf::SCFDialect", "acc::OpenACCDialect"];
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// OpenACCToLLVM
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -503,13 +503,13 @@ def OpenACC_UpdateOp : OpenACC_Op<"update", [AttrSizedOperandSegments]> {
|
|||
```
|
||||
}];
|
||||
|
||||
let arguments = (ins Optional<IntOrIndex>:$asyncOperand,
|
||||
let arguments = (ins Optional<I1>:$ifCond,
|
||||
Optional<IntOrIndex>:$asyncOperand,
|
||||
Optional<IntOrIndex>:$waitDevnum,
|
||||
Variadic<IntOrIndex>:$waitOperands,
|
||||
UnitAttr:$async,
|
||||
UnitAttr:$wait,
|
||||
Variadic<IntOrIndex>:$deviceTypeOperands,
|
||||
Optional<I1>:$ifCond,
|
||||
Variadic<AnyType>:$hostOperands,
|
||||
Variadic<AnyType>:$deviceOperands,
|
||||
UnitAttr:$ifPresent);
|
||||
|
|
|
@ -12,6 +12,7 @@ add_subdirectory(LinalgToSPIRV)
|
|||
add_subdirectory(LinalgToStandard)
|
||||
add_subdirectory(MathToLibm)
|
||||
add_subdirectory(OpenACCToLLVM)
|
||||
add_subdirectory(OpenACCToSCF)
|
||||
add_subdirectory(OpenMPToLLVM)
|
||||
add_subdirectory(PDLToPDLInterp)
|
||||
add_subdirectory(SCFToGPU)
|
||||
|
|
|
@ -0,0 +1,15 @@
|
|||
add_mlir_conversion_library(MLIROpenACCToSCF
|
||||
OpenACCToSCF.cpp
|
||||
|
||||
ADDITIONAL_HEADER_DIRS
|
||||
${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/OpenACCToSCF
|
||||
|
||||
DEPENDS
|
||||
MLIRConversionPassIncGen
|
||||
|
||||
LINK_LIBS PUBLIC
|
||||
MLIRIR
|
||||
MLIROpenACC
|
||||
MLIRTransforms
|
||||
MLIRSCF
|
||||
)
|
|
@ -0,0 +1,90 @@
|
|||
//===- OpenACCToSCF.cpp - OpenACC condition to SCF if conversion ----------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "../PassDetail.h"
|
||||
#include "mlir/Conversion/OpenACCToSCF/ConvertOpenACCToSCF.h"
|
||||
#include "mlir/Dialect/OpenACC/OpenACC.h"
|
||||
#include "mlir/Dialect/SCF/SCF.h"
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Conversion patterns
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
namespace {
|
||||
/// Pattern to transform the `ifCond` on operation without region into a scf.if
|
||||
/// and move the operation into the `then` region.
|
||||
template <typename OpTy>
|
||||
class ExpandIfCondition : public OpRewritePattern<OpTy> {
|
||||
using OpRewritePattern<OpTy>::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(OpTy op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
// Early exit if there is no condition.
|
||||
if (!op.ifCond())
|
||||
return success();
|
||||
|
||||
// Condition is not a constant.
|
||||
if (!op.ifCond().template getDefiningOp<ConstantOp>()) {
|
||||
auto ifOp = rewriter.create<scf::IfOp>(op.getLoc(), TypeRange(),
|
||||
op.ifCond(), false);
|
||||
rewriter.updateRootInPlace(op, [&]() { op.ifCondMutable().erase(0); });
|
||||
auto thenBodyBuilder = ifOp.getThenBodyBuilder();
|
||||
thenBodyBuilder.setListener(rewriter.getListener());
|
||||
thenBodyBuilder.clone(*op.getOperation());
|
||||
rewriter.eraseOp(op);
|
||||
}
|
||||
|
||||
return success();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
void mlir::populateOpenACCToSCFConversionPatterns(RewritePatternSet &patterns) {
|
||||
patterns.add<ExpandIfCondition<acc::EnterDataOp>>(patterns.getContext());
|
||||
patterns.add<ExpandIfCondition<acc::ExitDataOp>>(patterns.getContext());
|
||||
patterns.add<ExpandIfCondition<acc::UpdateOp>>(patterns.getContext());
|
||||
}
|
||||
|
||||
namespace {
|
||||
struct ConvertOpenACCToSCFPass
|
||||
: public ConvertOpenACCToSCFBase<ConvertOpenACCToSCFPass> {
|
||||
void runOnOperation() override;
|
||||
};
|
||||
} // namespace
|
||||
|
||||
void ConvertOpenACCToSCFPass::runOnOperation() {
|
||||
auto op = getOperation();
|
||||
auto *context = op.getContext();
|
||||
|
||||
RewritePatternSet patterns(context);
|
||||
ConversionTarget target(*context);
|
||||
populateOpenACCToSCFConversionPatterns(patterns);
|
||||
|
||||
target.addLegalDialect<scf::SCFDialect>();
|
||||
target.addLegalDialect<acc::OpenACCDialect>();
|
||||
|
||||
target.addDynamicallyLegalOp<acc::EnterDataOp>(
|
||||
[](acc::EnterDataOp op) { return !op.ifCond(); });
|
||||
|
||||
target.addDynamicallyLegalOp<acc::ExitDataOp>(
|
||||
[](acc::ExitDataOp op) { return !op.ifCond(); });
|
||||
|
||||
target.addDynamicallyLegalOp<acc::UpdateOp>(
|
||||
[](acc::UpdateOp op) { return !op.ifCond(); });
|
||||
|
||||
if (failed(applyPartialConversion(op, target, std::move(patterns))))
|
||||
signalPassFailure();
|
||||
}
|
||||
|
||||
std::unique_ptr<OperationPass<ModuleOp>> mlir::createConvertOpenACCToSCFPass() {
|
||||
return std::make_unique<ConvertOpenACCToSCFPass>();
|
||||
}
|
|
@ -19,6 +19,10 @@ class StandardOpsDialect;
|
|||
template <typename ConcreteDialect>
|
||||
void registerDialect(DialectRegistry ®istry);
|
||||
|
||||
namespace acc {
|
||||
class OpenACCDialect;
|
||||
} // end namespace acc
|
||||
|
||||
namespace complex {
|
||||
class ComplexDialect;
|
||||
} // end namespace complex
|
||||
|
|
|
@ -0,0 +1,35 @@
|
|||
// RUN: mlir-opt %s -convert-openacc-to-scf -split-input-file | FileCheck %s
|
||||
|
||||
func @testenterdataop(%a: memref<10xf32>, %ifCond: i1) -> () {
|
||||
acc.enter_data if(%ifCond) create(%a: memref<10xf32>)
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK: func @testenterdataop(%{{.*}}: memref<10xf32>, [[IFCOND:%.*]]: i1)
|
||||
// CHECK: scf.if [[IFCOND]] {
|
||||
// CHECK-NEXT: acc.enter_data create(%{{.*}} : memref<10xf32>)
|
||||
// CHECK-NEXT: }
|
||||
|
||||
// -----
|
||||
|
||||
func @testexitdataop(%a: memref<10xf32>, %ifCond: i1) -> () {
|
||||
acc.exit_data if(%ifCond) delete(%a: memref<10xf32>)
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK: func @testexitdataop(%{{.*}}: memref<10xf32>, [[IFCOND:%.*]]: i1)
|
||||
// CHECK: scf.if [[IFCOND]] {
|
||||
// CHECK-NEXT: acc.exit_data delete(%{{.*}} : memref<10xf32>)
|
||||
// CHECK-NEXT: }
|
||||
|
||||
// -----
|
||||
|
||||
func @testupdateop(%a: memref<10xf32>, %ifCond: i1) -> () {
|
||||
acc.update if(%ifCond) host(%a: memref<10xf32>)
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK: func @testupdateop(%{{.*}}: memref<10xf32>, [[IFCOND:%.*]]: i1)
|
||||
// CHECK: scf.if [[IFCOND]] {
|
||||
// CHECK-NEXT: acc.update host(%{{.*}} : memref<10xf32>)
|
||||
// CHECK-NEXT: }
|
Loading…
Reference in New Issue