forked from OSchip/llvm-project
[mlir] Support WsLoopOp in OpenMP to LLVM dialect conversion
It is a simple conversion that only requires to change the region argument types, generalize it from ParallelOp. Reviewed By: kiranchandramohan Differential Revision: https://reviews.llvm.org/D91989
This commit is contained in:
parent
0a20660c8f
commit
f7d033f4d8
|
@ -16,18 +16,23 @@
|
||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
struct ParallelOpConversion : public ConvertToLLVMPattern {
|
/// A pattern that converts the region arguments in a single-region OpenMP
|
||||||
explicit ParallelOpConversion(MLIRContext *context,
|
/// operation to the LLVM dialect. The body of the region is not modified and is
|
||||||
LLVMTypeConverter &typeConverter)
|
/// expected to either be processed by the conversion infrastructure or already
|
||||||
: ConvertToLLVMPattern(omp::ParallelOp::getOperationName(), context,
|
/// contain ops compatible with LLVM dialect types.
|
||||||
|
template <typename OpType>
|
||||||
|
struct RegionOpConversion : public ConvertToLLVMPattern {
|
||||||
|
explicit RegionOpConversion(MLIRContext *context,
|
||||||
|
LLVMTypeConverter &typeConverter)
|
||||||
|
: ConvertToLLVMPattern(OpType::getOperationName(), context,
|
||||||
typeConverter) {}
|
typeConverter) {}
|
||||||
|
|
||||||
LogicalResult
|
LogicalResult
|
||||||
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||||
ConversionPatternRewriter &rewriter) const override {
|
ConversionPatternRewriter &rewriter) const override {
|
||||||
auto curOp = cast<omp::ParallelOp>(op);
|
auto curOp = cast<OpType>(op);
|
||||||
auto newOp = rewriter.create<omp::ParallelOp>(curOp.getLoc(), TypeRange(),
|
auto newOp = rewriter.create<OpType>(curOp.getLoc(), TypeRange(), operands,
|
||||||
operands, curOp.getAttrs());
|
curOp.getAttrs());
|
||||||
rewriter.inlineRegionBefore(curOp.region(), newOp.region(),
|
rewriter.inlineRegionBefore(curOp.region(), newOp.region(),
|
||||||
newOp.region().end());
|
newOp.region().end());
|
||||||
if (failed(rewriter.convertRegionTypes(&newOp.region(), typeConverter)))
|
if (failed(rewriter.convertRegionTypes(&newOp.region(), typeConverter)))
|
||||||
|
@ -42,7 +47,8 @@ struct ParallelOpConversion : public ConvertToLLVMPattern {
|
||||||
void mlir::populateOpenMPToLLVMConversionPatterns(
|
void mlir::populateOpenMPToLLVMConversionPatterns(
|
||||||
MLIRContext *context, LLVMTypeConverter &converter,
|
MLIRContext *context, LLVMTypeConverter &converter,
|
||||||
OwningRewritePatternList &patterns) {
|
OwningRewritePatternList &patterns) {
|
||||||
patterns.insert<ParallelOpConversion>(context, converter);
|
patterns.insert<RegionOpConversion<omp::ParallelOp>,
|
||||||
|
RegionOpConversion<omp::WsLoopOp>>(context, converter);
|
||||||
}
|
}
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
@ -63,8 +69,8 @@ void ConvertOpenMPToLLVMPass::runOnOperation() {
|
||||||
populateOpenMPToLLVMConversionPatterns(context, converter, patterns);
|
populateOpenMPToLLVMConversionPatterns(context, converter, patterns);
|
||||||
|
|
||||||
LLVMConversionTarget target(getContext());
|
LLVMConversionTarget target(getContext());
|
||||||
target.addDynamicallyLegalOp<omp::ParallelOp>(
|
target.addDynamicallyLegalOp<omp::ParallelOp, omp::WsLoopOp>(
|
||||||
[&](omp::ParallelOp op) { return converter.isLegal(&op.getRegion()); });
|
[&](Operation *op) { return converter.isLegal(&op->getRegion(0)); });
|
||||||
target.addLegalOp<omp::TerminatorOp, omp::TaskyieldOp, omp::FlushOp,
|
target.addLegalOp<omp::TerminatorOp, omp::TaskyieldOp, omp::FlushOp,
|
||||||
omp::BarrierOp, omp::TaskwaitOp>();
|
omp::BarrierOp, omp::TaskwaitOp>();
|
||||||
if (failed(applyPartialConversion(module, target, std::move(patterns))))
|
if (failed(applyPartialConversion(module, target, std::move(patterns))))
|
||||||
|
|
|
@ -28,3 +28,22 @@ func @branch_loop() {
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: @wsloop
|
||||||
|
// CHECK: (%[[ARG0:.*]]: !llvm.i64, %[[ARG1:.*]]: !llvm.i64, %[[ARG2:.*]]: !llvm.i64, %[[ARG3:.*]]: !llvm.i64, %[[ARG4:.*]]: !llvm.i64, %[[ARG5:.*]]: !llvm.i64)
|
||||||
|
func @wsloop(%arg0: index, %arg1: index, %arg2: index, %arg3: index, %arg4: index, %arg5: index) {
|
||||||
|
// CHECK: omp.parallel
|
||||||
|
omp.parallel {
|
||||||
|
// CHECK: omp.wsloop
|
||||||
|
// CHECK: (%[[ARG0]], %[[ARG1]], %[[ARG2]], %[[ARG3]], %[[ARG4]], %[[ARG5]])
|
||||||
|
"omp.wsloop"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5) ( {
|
||||||
|
// CHECK: ^{{.*}}(%[[ARG6:.*]]: !llvm.i64, %[[ARG7:.*]]: !llvm.i64):
|
||||||
|
^bb0(%arg6: index, %arg7: index): // no predecessors
|
||||||
|
// CHECK: "test.payload"(%[[ARG6]], %[[ARG7]]) : (!llvm.i64, !llvm.i64) -> ()
|
||||||
|
"test.payload"(%arg6, %arg7) : (index, index) -> ()
|
||||||
|
omp.yield
|
||||||
|
}) {operand_segment_sizes = dense<[2, 2, 2, 0, 0, 0, 0, 0, 0]> : vector<9xi32>} : (index, index, index, index, index, index) -> ()
|
||||||
|
omp.terminator
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue