forked from OSchip/llvm-project
[mlir] Support WsLoopOp in OpenMP to LLVM dialect conversion
It is a simple conversion that only requires to change the region argument types, generalize it from ParallelOp. Reviewed By: kiranchandramohan Differential Revision: https://reviews.llvm.org/D91989
This commit is contained in:
parent
0a20660c8f
commit
f7d033f4d8
|
@ -16,18 +16,23 @@
|
|||
using namespace mlir;
|
||||
|
||||
namespace {
|
||||
struct ParallelOpConversion : public ConvertToLLVMPattern {
|
||||
explicit ParallelOpConversion(MLIRContext *context,
|
||||
LLVMTypeConverter &typeConverter)
|
||||
: ConvertToLLVMPattern(omp::ParallelOp::getOperationName(), context,
|
||||
/// A pattern that converts the region arguments in a single-region OpenMP
|
||||
/// operation to the LLVM dialect. The body of the region is not modified and is
|
||||
/// expected to either be processed by the conversion infrastructure or already
|
||||
/// contain ops compatible with LLVM dialect types.
|
||||
template <typename OpType>
|
||||
struct RegionOpConversion : public ConvertToLLVMPattern {
|
||||
explicit RegionOpConversion(MLIRContext *context,
|
||||
LLVMTypeConverter &typeConverter)
|
||||
: ConvertToLLVMPattern(OpType::getOperationName(), context,
|
||||
typeConverter) {}
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto curOp = cast<omp::ParallelOp>(op);
|
||||
auto newOp = rewriter.create<omp::ParallelOp>(curOp.getLoc(), TypeRange(),
|
||||
operands, curOp.getAttrs());
|
||||
auto curOp = cast<OpType>(op);
|
||||
auto newOp = rewriter.create<OpType>(curOp.getLoc(), TypeRange(), operands,
|
||||
curOp.getAttrs());
|
||||
rewriter.inlineRegionBefore(curOp.region(), newOp.region(),
|
||||
newOp.region().end());
|
||||
if (failed(rewriter.convertRegionTypes(&newOp.region(), typeConverter)))
|
||||
|
@ -42,7 +47,8 @@ struct ParallelOpConversion : public ConvertToLLVMPattern {
|
|||
void mlir::populateOpenMPToLLVMConversionPatterns(
|
||||
MLIRContext *context, LLVMTypeConverter &converter,
|
||||
OwningRewritePatternList &patterns) {
|
||||
patterns.insert<ParallelOpConversion>(context, converter);
|
||||
patterns.insert<RegionOpConversion<omp::ParallelOp>,
|
||||
RegionOpConversion<omp::WsLoopOp>>(context, converter);
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
@ -63,8 +69,8 @@ void ConvertOpenMPToLLVMPass::runOnOperation() {
|
|||
populateOpenMPToLLVMConversionPatterns(context, converter, patterns);
|
||||
|
||||
LLVMConversionTarget target(getContext());
|
||||
target.addDynamicallyLegalOp<omp::ParallelOp>(
|
||||
[&](omp::ParallelOp op) { return converter.isLegal(&op.getRegion()); });
|
||||
target.addDynamicallyLegalOp<omp::ParallelOp, omp::WsLoopOp>(
|
||||
[&](Operation *op) { return converter.isLegal(&op->getRegion(0)); });
|
||||
target.addLegalOp<omp::TerminatorOp, omp::TaskyieldOp, omp::FlushOp,
|
||||
omp::BarrierOp, omp::TaskwaitOp>();
|
||||
if (failed(applyPartialConversion(module, target, std::move(patterns))))
|
||||
|
|
|
@ -28,3 +28,22 @@ func @branch_loop() {
|
|||
}
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @wsloop
|
||||
// CHECK: (%[[ARG0:.*]]: !llvm.i64, %[[ARG1:.*]]: !llvm.i64, %[[ARG2:.*]]: !llvm.i64, %[[ARG3:.*]]: !llvm.i64, %[[ARG4:.*]]: !llvm.i64, %[[ARG5:.*]]: !llvm.i64)
|
||||
func @wsloop(%arg0: index, %arg1: index, %arg2: index, %arg3: index, %arg4: index, %arg5: index) {
|
||||
// CHECK: omp.parallel
|
||||
omp.parallel {
|
||||
// CHECK: omp.wsloop
|
||||
// CHECK: (%[[ARG0]], %[[ARG1]], %[[ARG2]], %[[ARG3]], %[[ARG4]], %[[ARG5]])
|
||||
"omp.wsloop"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5) ( {
|
||||
// CHECK: ^{{.*}}(%[[ARG6:.*]]: !llvm.i64, %[[ARG7:.*]]: !llvm.i64):
|
||||
^bb0(%arg6: index, %arg7: index): // no predecessors
|
||||
// CHECK: "test.payload"(%[[ARG6]], %[[ARG7]]) : (!llvm.i64, !llvm.i64) -> ()
|
||||
"test.payload"(%arg6, %arg7) : (index, index) -> ()
|
||||
omp.yield
|
||||
}) {operand_segment_sizes = dense<[2, 2, 2, 0, 0, 0, 0, 0, 0]> : vector<9xi32>} : (index, index, index, index, index, index) -> ()
|
||||
omp.terminator
|
||||
}
|
||||
return
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue