diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td index e69891ec9abf..ce96f6f5c3fb 100644 --- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td +++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td @@ -813,6 +813,20 @@ def AtomicReadOp : OpenMP_Op<"atomic.read", [AllTypesMatch<["x", "v"]>]> { `:` type($x) attr-dict }]; let hasVerifier = 1; + let extraClassDeclaration = [{ + /// The number of variable operands. + unsigned getNumVariableOperands() { + assert(x() && "expected 'x' operand"); + assert(v() && "expected 'v' operand"); + return 2; + } + + /// The i-th variable operand passed. + Value getVariableOperand(unsigned i) { + assert(0 <= i < 2 && "invalid index position for an operand"); + return i == 0 ? x() : v(); + } + }]; } def AtomicWriteOp : OpenMP_Op<"atomic.write"> { @@ -847,6 +861,20 @@ def AtomicWriteOp : OpenMP_Op<"atomic.write"> { attr-dict }]; let hasVerifier = 1; + let extraClassDeclaration = [{ + /// The number of variable operands. + unsigned getNumVariableOperands() { + assert(address() && "expected address operand"); + assert(value() && "expected value operand"); + return 2; + } + + /// The i-th variable operand passed. + Value getVariableOperand(unsigned i) { + assert(0 <= i < 2 && "invalid index position for an operand"); + return i == 0 ? address() : value(); + } + }]; } def AtomicUpdateOp : OpenMP_Op<"atomic.update", @@ -996,6 +1024,19 @@ def ThreadprivateOp : OpenMP_Op<"threadprivate"> { let assemblyFormat = [{ $sym_addr `:` type($sym_addr) `->` type($tls_addr) attr-dict }]; + let extraClassDeclaration = [{ + /// The number of variable operands. + unsigned getNumVariableOperands() { + assert(sym_addr() && "expected one variable operand"); + return 1; + } + + /// The i-th variable operand passed. + Value getVariableOperand(unsigned i) { + assert(i == 0 && "invalid index position for an operand"); + return sym_addr(); + } + }]; } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp b/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp index d16e51e0eb24..25eb2bf5ddc4 100644 --- a/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp +++ b/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp @@ -52,7 +52,23 @@ struct RegionLessOpConversion : public ConvertOpToLLVMPattern { LogicalResult matchAndRewrite(T curOp, typename T::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override { - rewriter.replaceOpWithNewOp(curOp, TypeRange(), adaptor.getOperands(), + TypeConverter *converter = ConvertToLLVMPattern::getTypeConverter(); + SmallVector resTypes; + if (failed(converter->convertTypes(curOp->getResultTypes(), resTypes))) + return failure(); + SmallVector convertedOperands; + for (unsigned idx = 0; idx < curOp.getNumVariableOperands(); ++idx) { + Value originalVariableOperand = curOp.getVariableOperand(idx); + if (!originalVariableOperand) + return failure(); + if (originalVariableOperand.getType().isa()) { + // TODO: Support memref type in variable operands + rewriter.notifyMatchFailure(curOp, "memref is not supported yet"); + } else { + convertedOperands.emplace_back(adaptor.getOperands()[idx]); + } + } + rewriter.replaceOpWithNewOp(curOp, resTypes, convertedOperands, curOp->getAttrs()); return success(); } @@ -65,10 +81,10 @@ void mlir::configureOpenMPToLLVMConversionLegality( mlir::omp::MasterOp>( [&](Operation *op) { return typeConverter.isLegal(&op->getRegion(0)); }); target - .addDynamicallyLegalOp( - [&](Operation *op) { - return typeConverter.isLegal(op->getOperandTypes()); - }); + .addDynamicallyLegalOp([&](Operation *op) { + return typeConverter.isLegal(op->getOperandTypes()); + }); } void mlir::populateOpenMPToLLVMConversionPatterns(LLVMTypeConverter &converter, @@ -77,7 +93,8 @@ void mlir::populateOpenMPToLLVMConversionPatterns(LLVMTypeConverter &converter, RegionOpConversion, RegionOpConversion, RegionLessOpConversion, - RegionLessOpConversion>(converter); + RegionLessOpConversion, + RegionLessOpConversion>(converter); } namespace { diff --git a/mlir/test/Conversion/OpenMPToLLVM/convert-to-llvmir.mlir b/mlir/test/Conversion/OpenMPToLLVM/convert-to-llvmir.mlir index cb1e2411e809..341e283ffd33 100644 --- a/mlir/test/Conversion/OpenMPToLLVM/convert-to-llvmir.mlir +++ b/mlir/test/Conversion/OpenMPToLLVM/convert-to-llvmir.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt -convert-openmp-to-llvm %s -split-input-file | FileCheck %s +// RUN: mlir-opt -convert-openmp-to-llvm -split-input-file %s | FileCheck %s // CHECK-LABEL: llvm.func @master_block_arg func.func @master_block_arg() { @@ -15,6 +15,8 @@ func.func @master_block_arg() { return } +// ----- + // CHECK-LABEL: llvm.func @branch_loop func.func @branch_loop() { %start = arith.constant 0 : index @@ -44,6 +46,8 @@ func.func @branch_loop() { return } +// ----- + // CHECK-LABEL: @wsloop // CHECK: (%[[ARG0:.*]]: i64, %[[ARG1:.*]]: i64, %[[ARG2:.*]]: i64, %[[ARG3:.*]]: i64, %[[ARG4:.*]]: i64, %[[ARG5:.*]]: i64) func.func @wsloop(%arg0: index, %arg1: index, %arg2: index, %arg3: index, %arg4: index, %arg5: index) { @@ -62,3 +66,35 @@ func.func @wsloop(%arg0: index, %arg1: index, %arg2: index, %arg3: index, %arg4: } return } + +// ----- + +// CHECK-LABEL: @atomic_write +// CHECK: (%[[ARG0:.*]]: !llvm.ptr) +// CHECK: %[[VAL0:.*]] = llvm.mlir.constant(1 : i32) : i32 +// CHECK: omp.atomic.write %[[ARG0]] = %[[VAL0]] hint(none) memory_order(relaxed) : !llvm.ptr, i32 +func.func @atomic_write(%a: !llvm.ptr) -> () { + %1 = arith.constant 1 : i32 + omp.atomic.write %a = %1 hint(none) memory_order(relaxed) : !llvm.ptr, i32 + return +} + +// ----- + +// CHECK-LABEL: @atomic_read +// CHECK: (%[[ARG0:.*]]: !llvm.ptr, %[[ARG1:.*]]: !llvm.ptr) +// CHECK: omp.atomic.read %[[ARG1]] = %[[ARG0]] memory_order(acquire) hint(contended) : !llvm.ptr +func.func @atomic_read(%a: !llvm.ptr, %b: !llvm.ptr) -> () { + omp.atomic.read %b = %a memory_order(acquire) hint(contended) : !llvm.ptr + return +} + +// ----- + +// CHECK-LABEL: @threadprivate +// CHECK: (%[[ARG0:.*]]: !llvm.ptr) +// CHECK: %[[VAL0:.*]] = omp.threadprivate %[[ARG0]] : !llvm.ptr -> !llvm.ptr +func.func @threadprivate(%a: !llvm.ptr) -> () { + %1 = omp.threadprivate %a : !llvm.ptr -> !llvm.ptr + return +}