[mlir] Support WsLoopOp in OpenMP to LLVM dialect conversion

It is a simple conversion that only requires to change the region argument
types, generalize it from ParallelOp.

Reviewed By: kiranchandramohan

Differential Revision: https://reviews.llvm.org/D91989
This commit is contained in:
Alex Zinenko 2020-11-23 20:45:30 +01:00
parent 0a20660c8f
commit f7d033f4d8
2 changed files with 35 additions and 10 deletions

View File

@ -16,18 +16,23 @@
using namespace mlir; using namespace mlir;
namespace { namespace {
struct ParallelOpConversion : public ConvertToLLVMPattern { /// A pattern that converts the region arguments in a single-region OpenMP
explicit ParallelOpConversion(MLIRContext *context, /// operation to the LLVM dialect. The body of the region is not modified and is
LLVMTypeConverter &typeConverter) /// expected to either be processed by the conversion infrastructure or already
: ConvertToLLVMPattern(omp::ParallelOp::getOperationName(), context, /// contain ops compatible with LLVM dialect types.
template <typename OpType>
struct RegionOpConversion : public ConvertToLLVMPattern {
explicit RegionOpConversion(MLIRContext *context,
LLVMTypeConverter &typeConverter)
: ConvertToLLVMPattern(OpType::getOperationName(), context,
typeConverter) {} typeConverter) {}
LogicalResult LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands, matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override { ConversionPatternRewriter &rewriter) const override {
auto curOp = cast<omp::ParallelOp>(op); auto curOp = cast<OpType>(op);
auto newOp = rewriter.create<omp::ParallelOp>(curOp.getLoc(), TypeRange(), auto newOp = rewriter.create<OpType>(curOp.getLoc(), TypeRange(), operands,
operands, curOp.getAttrs()); curOp.getAttrs());
rewriter.inlineRegionBefore(curOp.region(), newOp.region(), rewriter.inlineRegionBefore(curOp.region(), newOp.region(),
newOp.region().end()); newOp.region().end());
if (failed(rewriter.convertRegionTypes(&newOp.region(), typeConverter))) if (failed(rewriter.convertRegionTypes(&newOp.region(), typeConverter)))
@ -42,7 +47,8 @@ struct ParallelOpConversion : public ConvertToLLVMPattern {
void mlir::populateOpenMPToLLVMConversionPatterns( void mlir::populateOpenMPToLLVMConversionPatterns(
MLIRContext *context, LLVMTypeConverter &converter, MLIRContext *context, LLVMTypeConverter &converter,
OwningRewritePatternList &patterns) { OwningRewritePatternList &patterns) {
patterns.insert<ParallelOpConversion>(context, converter); patterns.insert<RegionOpConversion<omp::ParallelOp>,
RegionOpConversion<omp::WsLoopOp>>(context, converter);
} }
namespace { namespace {
@ -63,8 +69,8 @@ void ConvertOpenMPToLLVMPass::runOnOperation() {
populateOpenMPToLLVMConversionPatterns(context, converter, patterns); populateOpenMPToLLVMConversionPatterns(context, converter, patterns);
LLVMConversionTarget target(getContext()); LLVMConversionTarget target(getContext());
target.addDynamicallyLegalOp<omp::ParallelOp>( target.addDynamicallyLegalOp<omp::ParallelOp, omp::WsLoopOp>(
[&](omp::ParallelOp op) { return converter.isLegal(&op.getRegion()); }); [&](Operation *op) { return converter.isLegal(&op->getRegion(0)); });
target.addLegalOp<omp::TerminatorOp, omp::TaskyieldOp, omp::FlushOp, target.addLegalOp<omp::TerminatorOp, omp::TaskyieldOp, omp::FlushOp,
omp::BarrierOp, omp::TaskwaitOp>(); omp::BarrierOp, omp::TaskwaitOp>();
if (failed(applyPartialConversion(module, target, std::move(patterns)))) if (failed(applyPartialConversion(module, target, std::move(patterns))))

View File

@ -28,3 +28,22 @@ func @branch_loop() {
} }
return return
} }
// CHECK-LABEL: @wsloop
// CHECK: (%[[ARG0:.*]]: !llvm.i64, %[[ARG1:.*]]: !llvm.i64, %[[ARG2:.*]]: !llvm.i64, %[[ARG3:.*]]: !llvm.i64, %[[ARG4:.*]]: !llvm.i64, %[[ARG5:.*]]: !llvm.i64)
func @wsloop(%arg0: index, %arg1: index, %arg2: index, %arg3: index, %arg4: index, %arg5: index) {
// CHECK: omp.parallel
omp.parallel {
// CHECK: omp.wsloop
// CHECK: (%[[ARG0]], %[[ARG1]], %[[ARG2]], %[[ARG3]], %[[ARG4]], %[[ARG5]])
"omp.wsloop"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5) ( {
// CHECK: ^{{.*}}(%[[ARG6:.*]]: !llvm.i64, %[[ARG7:.*]]: !llvm.i64):
^bb0(%arg6: index, %arg7: index): // no predecessors
// CHECK: "test.payload"(%[[ARG6]], %[[ARG7]]) : (!llvm.i64, !llvm.i64) -> ()
"test.payload"(%arg6, %arg7) : (index, index) -> ()
omp.yield
}) {operand_segment_sizes = dense<[2, 2, 2, 0, 0, 0, 0, 0, 0]> : vector<9xi32>} : (index, index, index, index, index, index) -> ()
omp.terminator
}
return
}