forked from OSchip/llvm-project
[fir] Add fir.select and fir.select_rank FIR to LLVM IR conversion patterns
The `fir.select` and `fir.select_rank` are lowered to llvm.switch. This patch is part of the upstreaming effort from fir-dev branch. Reviewed By: mehdi_amini Differential Revision: https://reviews.llvm.org/D113089 Co-authored-by: Jean Perier <jperier@nvidia.com> Co-authored-by: Eric Schweitz <eschweitz@nvidia.com>
This commit is contained in:
parent
3a11fb572c
commit
8c23990949
|
@ -480,7 +480,7 @@ class fir_SwitchTerminatorOp<string mnemonic, list<OpTrait> traits = []> :
|
|||
|
||||
// The number of destination conditions that may be tested
|
||||
unsigned getNumConditions() {
|
||||
return (*this)->getAttrOfType<mlir::ArrayAttr>(getCasesAttr()).size();
|
||||
return getCases().size();
|
||||
}
|
||||
|
||||
// The selector is the value being tested to determine the destination
|
||||
|
@ -488,6 +488,9 @@ class fir_SwitchTerminatorOp<string mnemonic, list<OpTrait> traits = []> :
|
|||
mlir::Value getSelector(llvm::ArrayRef<mlir::Value> operands) {
|
||||
return operands[0];
|
||||
}
|
||||
mlir::Value getSelector(mlir::ValueRange operands) {
|
||||
return operands.front();
|
||||
}
|
||||
|
||||
// The number of blocks that may be branched to
|
||||
unsigned getNumDest() { return (*this)->getNumSuccessors(); }
|
||||
|
@ -498,6 +501,8 @@ class fir_SwitchTerminatorOp<string mnemonic, list<OpTrait> traits = []> :
|
|||
|
||||
llvm::Optional<llvm::ArrayRef<mlir::Value>> getSuccessorOperands(
|
||||
llvm::ArrayRef<mlir::Value> operands, unsigned cond);
|
||||
llvm::Optional<mlir::ValueRange> getSuccessorOperands(
|
||||
mlir::ValueRange operands, unsigned cond);
|
||||
using BranchOpInterfaceTrait::getSuccessorOperands;
|
||||
|
||||
// Helper function to deal with Optional operand forms
|
||||
|
@ -510,6 +515,10 @@ class fir_SwitchTerminatorOp<string mnemonic, list<OpTrait> traits = []> :
|
|||
p.printSuccessor(succ);
|
||||
}
|
||||
|
||||
mlir::ArrayAttr getCases() {
|
||||
return (*this)->getAttrOfType<mlir::ArrayAttr>(getCasesAttr());
|
||||
}
|
||||
|
||||
unsigned targetOffsetSize();
|
||||
}];
|
||||
}
|
||||
|
|
|
@ -174,6 +174,78 @@ struct GlobalOpConversion : public FIROpConversion<fir::GlobalOp> {
|
|||
}
|
||||
};
|
||||
|
||||
template <typename OP>
|
||||
void selectMatchAndRewrite(fir::LLVMTypeConverter &lowering, OP select,
|
||||
typename OP::Adaptor adaptor,
|
||||
mlir::ConversionPatternRewriter &rewriter) {
|
||||
unsigned conds = select.getNumConditions();
|
||||
auto cases = select.getCases().getValue();
|
||||
mlir::Value selector = adaptor.selector();
|
||||
auto loc = select.getLoc();
|
||||
assert(conds > 0 && "select must have cases");
|
||||
|
||||
llvm::SmallVector<mlir::Block *> destinations;
|
||||
llvm::SmallVector<mlir::ValueRange> destinationsOperands;
|
||||
mlir::Block *defaultDestination;
|
||||
mlir::ValueRange defaultOperands;
|
||||
llvm::SmallVector<int32_t> caseValues;
|
||||
|
||||
for (unsigned t = 0; t != conds; ++t) {
|
||||
mlir::Block *dest = select.getSuccessor(t);
|
||||
auto destOps = select.getSuccessorOperands(adaptor.getOperands(), t);
|
||||
const mlir::Attribute &attr = cases[t];
|
||||
if (auto intAttr = attr.template dyn_cast<mlir::IntegerAttr>()) {
|
||||
destinations.push_back(dest);
|
||||
destinationsOperands.push_back(destOps.hasValue() ? *destOps
|
||||
: ValueRange());
|
||||
caseValues.push_back(intAttr.getInt());
|
||||
continue;
|
||||
}
|
||||
assert(attr.template dyn_cast_or_null<mlir::UnitAttr>());
|
||||
assert((t + 1 == conds) && "unit must be last");
|
||||
defaultDestination = dest;
|
||||
defaultOperands = destOps.hasValue() ? *destOps : ValueRange();
|
||||
}
|
||||
|
||||
// LLVM::SwitchOp takes a i32 type for the selector.
|
||||
if (select.getSelector().getType() != rewriter.getI32Type())
|
||||
selector =
|
||||
rewriter.create<LLVM::TruncOp>(loc, rewriter.getI32Type(), selector);
|
||||
|
||||
rewriter.replaceOpWithNewOp<mlir::LLVM::SwitchOp>(
|
||||
select, selector,
|
||||
/*defaultDestination=*/defaultDestination,
|
||||
/*defaultOperands=*/defaultOperands,
|
||||
/*caseValues=*/caseValues,
|
||||
/*caseDestinations=*/destinations,
|
||||
/*caseOperands=*/destinationsOperands,
|
||||
/*branchWeights=*/ArrayRef<int32_t>());
|
||||
}
|
||||
|
||||
/// conversion of fir::SelectOp to an if-then-else ladder
|
||||
struct SelectOpConversion : public FIROpConversion<fir::SelectOp> {
|
||||
using FIROpConversion::FIROpConversion;
|
||||
|
||||
mlir::LogicalResult
|
||||
matchAndRewrite(fir::SelectOp op, OpAdaptor adaptor,
|
||||
mlir::ConversionPatternRewriter &rewriter) const override {
|
||||
selectMatchAndRewrite<fir::SelectOp>(lowerTy(), op, adaptor, rewriter);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
/// conversion of fir::SelectRankOp to an if-then-else ladder
|
||||
struct SelectRankOpConversion : public FIROpConversion<fir::SelectRankOp> {
|
||||
using FIROpConversion::FIROpConversion;
|
||||
|
||||
mlir::LogicalResult
|
||||
matchAndRewrite(fir::SelectRankOp op, OpAdaptor adaptor,
|
||||
mlir::ConversionPatternRewriter &rewriter) const override {
|
||||
selectMatchAndRewrite<fir::SelectRankOp>(lowerTy(), op, adaptor, rewriter);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
// convert to LLVM IR dialect `undef`
|
||||
struct UndefOpConversion : public FIROpConversion<fir::UndefOp> {
|
||||
using FIROpConversion::FIROpConversion;
|
||||
|
@ -318,8 +390,9 @@ public:
|
|||
fir::LLVMTypeConverter typeConverter{getModule()};
|
||||
mlir::OwningRewritePatternList pattern(context);
|
||||
pattern.insert<AddrOfOpConversion, HasValueOpConversion, GlobalOpConversion,
|
||||
InsertOnRangeOpConversion, UndefOpConversion,
|
||||
UnreachableOpConversion, ZeroOpConversion>(typeConverter);
|
||||
InsertOnRangeOpConversion, SelectOpConversion,
|
||||
SelectRankOpConversion, UnreachableOpConversion,
|
||||
ZeroOpConversion, UndefOpConversion>(typeConverter);
|
||||
mlir::populateStdToLLVMConversionPatterns(typeConverter, pattern);
|
||||
mlir::arith::populateArithmeticToLLVMConversionPatterns(typeConverter,
|
||||
pattern);
|
||||
|
|
|
@ -2264,6 +2264,15 @@ fir::SelectOp::getSuccessorOperands(llvm::ArrayRef<mlir::Value> operands,
|
|||
return {getSubOperands(oper, getSubOperands(2, operands, segments), a)};
|
||||
}
|
||||
|
||||
llvm::Optional<mlir::ValueRange>
|
||||
fir::SelectOp::getSuccessorOperands(mlir::ValueRange operands, unsigned oper) {
|
||||
auto a =
|
||||
(*this)->getAttrOfType<mlir::DenseIntElementsAttr>(getTargetOffsetAttr());
|
||||
auto segments = (*this)->getAttrOfType<mlir::DenseIntElementsAttr>(
|
||||
getOperandSegmentSizeAttr());
|
||||
return {getSubOperands(oper, getSubOperands(2, operands, segments), a)};
|
||||
}
|
||||
|
||||
unsigned fir::SelectOp::targetOffsetSize() {
|
||||
return denseElementsSize((*this)->getAttrOfType<mlir::DenseIntElementsAttr>(
|
||||
getTargetOffsetAttr()));
|
||||
|
@ -2557,6 +2566,16 @@ fir::SelectRankOp::getSuccessorOperands(llvm::ArrayRef<mlir::Value> operands,
|
|||
return {getSubOperands(oper, getSubOperands(2, operands, segments), a)};
|
||||
}
|
||||
|
||||
llvm::Optional<mlir::ValueRange>
|
||||
fir::SelectRankOp::getSuccessorOperands(mlir::ValueRange operands,
|
||||
unsigned oper) {
|
||||
auto a =
|
||||
(*this)->getAttrOfType<mlir::DenseIntElementsAttr>(getTargetOffsetAttr());
|
||||
auto segments = (*this)->getAttrOfType<mlir::DenseIntElementsAttr>(
|
||||
getOperandSegmentSizeAttr());
|
||||
return {getSubOperands(oper, getSubOperands(2, operands, segments), a)};
|
||||
}
|
||||
|
||||
unsigned fir::SelectRankOp::targetOffsetSize() {
|
||||
return denseElementsSize((*this)->getAttrOfType<mlir::DenseIntElementsAttr>(
|
||||
getTargetOffsetAttr()));
|
||||
|
|
|
@ -167,3 +167,95 @@ func @zero_test_float() {
|
|||
func @test_unreachable() {
|
||||
fir.unreachable
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Test `fir.select` operation conversion pattern.
|
||||
// Check that the if-then-else ladder is correctly constructed and that we
|
||||
// branch to the correct block.
|
||||
|
||||
func @select(%arg : index, %arg2 : i32) -> i32 {
|
||||
%0 = arith.constant 1 : i32
|
||||
%1 = arith.constant 2 : i32
|
||||
%2 = arith.constant 3 : i32
|
||||
%3 = arith.constant 4 : i32
|
||||
fir.select %arg:index [ 1, ^bb1(%0:i32),
|
||||
2, ^bb2(%2,%arg,%arg2:i32,index,i32),
|
||||
3, ^bb3(%arg2,%2:i32,i32),
|
||||
4, ^bb4(%1:i32),
|
||||
unit, ^bb5 ]
|
||||
^bb1(%a : i32) :
|
||||
return %a : i32
|
||||
^bb2(%b : i32, %b2 : index, %b3:i32) :
|
||||
%castidx = arith.index_cast %b2 : index to i32
|
||||
%4 = arith.addi %b, %castidx : i32
|
||||
%5 = arith.addi %4, %b3 : i32
|
||||
return %5 : i32
|
||||
^bb3(%c:i32, %c2:i32) :
|
||||
%6 = arith.addi %c, %c2 : i32
|
||||
return %6 : i32
|
||||
^bb4(%d : i32) :
|
||||
return %d : i32
|
||||
^bb5 :
|
||||
%zero = arith.constant 0 : i32
|
||||
return %zero : i32
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @select(
|
||||
// CHECK-SAME: %[[SELECTVALUE:.*]]: [[IDX:.*]],
|
||||
// CHECK-SAME: %[[ARG1:.*]]: i32)
|
||||
// CHECK: %[[C0:.*]] = llvm.mlir.constant(1 : i32) : i32
|
||||
// CHECK: %[[C1:.*]] = llvm.mlir.constant(2 : i32) : i32
|
||||
// CHECK: %[[C2:.*]] = llvm.mlir.constant(3 : i32) : i32
|
||||
// CHECK: %[[SELECTOR:.*]] = llvm.trunc %[[SELECTVALUE]] : i{{.*}} to i32
|
||||
// CHECK: llvm.switch %[[SELECTOR]], ^bb5 [
|
||||
// CHECK: 1: ^bb1(%[[C0]] : i32),
|
||||
// CHECK: 2: ^bb2(%[[C2]], %[[SELECTVALUE]], %[[ARG1]] : i32, [[IDX]], i32),
|
||||
// CHECK: 3: ^bb3(%[[ARG1]], %[[C2]] : i32, i32),
|
||||
// CHECK: 4: ^bb4(%[[C1]] : i32)
|
||||
// CHECK: ]
|
||||
|
||||
// -----
|
||||
|
||||
// Test `fir.select_rank` operation conversion pattern.
|
||||
// Check that the if-then-else ladder is correctly constructed and that we
|
||||
// branch to the correct block.
|
||||
|
||||
func @select_rank(%arg : i32, %arg2 : i32) -> i32 {
|
||||
%0 = arith.constant 1 : i32
|
||||
%1 = arith.constant 2 : i32
|
||||
%2 = arith.constant 3 : i32
|
||||
%3 = arith.constant 4 : i32
|
||||
fir.select_rank %arg:i32 [ 1, ^bb1(%0:i32),
|
||||
2, ^bb2(%2,%arg,%arg2:i32,i32,i32),
|
||||
3, ^bb3(%arg2,%2:i32,i32),
|
||||
4, ^bb4(%1:i32),
|
||||
unit, ^bb5 ]
|
||||
^bb1(%a : i32) :
|
||||
return %a : i32
|
||||
^bb2(%b : i32, %b2 : i32, %b3:i32) :
|
||||
%4 = arith.addi %b, %b2 : i32
|
||||
%5 = arith.addi %4, %b3 : i32
|
||||
return %5 : i32
|
||||
^bb3(%c:i32, %c2:i32) :
|
||||
%6 = arith.addi %c, %c2 : i32
|
||||
return %6 : i32
|
||||
^bb4(%d : i32) :
|
||||
return %d : i32
|
||||
^bb5 :
|
||||
%zero = arith.constant 0 : i32
|
||||
return %zero : i32
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @select_rank(
|
||||
// CHECK-SAME: %[[SELECTVALUE:.*]]: i32,
|
||||
// CHECK-SAME: %[[ARG1:.*]]: i32)
|
||||
// CHECK: %[[C0:.*]] = llvm.mlir.constant(1 : i32) : i32
|
||||
// CHECK: %[[C1:.*]] = llvm.mlir.constant(2 : i32) : i32
|
||||
// CHECK: %[[C2:.*]] = llvm.mlir.constant(3 : i32) : i32
|
||||
// CHECK: llvm.switch %[[SELECTVALUE]], ^bb5 [
|
||||
// CHECK: 1: ^bb1(%[[C0]] : i32),
|
||||
// CHECK: 2: ^bb2(%[[C2]], %[[SELECTVALUE]], %[[ARG1]] : i32, i32, i32),
|
||||
// CHECK: 3: ^bb3(%[[ARG1]], %[[C2]] : i32, i32),
|
||||
// CHECK: 4: ^bb4(%[[C1]] : i32)
|
||||
// CHECK: ]
|
||||
|
|
Loading…
Reference in New Issue