[fir] Add fir.select and fir.select_rank FIR to LLVM IR conversion patterns

The `fir.select` and `fir.select_rank` are lowered to llvm.switch.

This patch is part of the upstreaming effort from fir-dev branch.

Reviewed By: mehdi_amini

Differential Revision: https://reviews.llvm.org/D113089

Co-authored-by: Jean Perier <jperier@nvidia.com>
Co-authored-by: Eric Schweitz <eschweitz@nvidia.com>
This commit is contained in:
Valentin Clement 2021-11-05 12:54:12 +01:00
parent 3a11fb572c
commit 8c23990949
No known key found for this signature in database
GPG Key ID: 086D54783C928776
4 changed files with 196 additions and 3 deletions

View File

@ -480,7 +480,7 @@ class fir_SwitchTerminatorOp<string mnemonic, list<OpTrait> traits = []> :
// The number of destination conditions that may be tested
unsigned getNumConditions() {
return (*this)->getAttrOfType<mlir::ArrayAttr>(getCasesAttr()).size();
return getCases().size();
}
// The selector is the value being tested to determine the destination
@ -488,6 +488,9 @@ class fir_SwitchTerminatorOp<string mnemonic, list<OpTrait> traits = []> :
mlir::Value getSelector(llvm::ArrayRef<mlir::Value> operands) {
return operands[0];
}
mlir::Value getSelector(mlir::ValueRange operands) {
return operands.front();
}
// The number of blocks that may be branched to
unsigned getNumDest() { return (*this)->getNumSuccessors(); }
@ -498,6 +501,8 @@ class fir_SwitchTerminatorOp<string mnemonic, list<OpTrait> traits = []> :
llvm::Optional<llvm::ArrayRef<mlir::Value>> getSuccessorOperands(
llvm::ArrayRef<mlir::Value> operands, unsigned cond);
llvm::Optional<mlir::ValueRange> getSuccessorOperands(
mlir::ValueRange operands, unsigned cond);
using BranchOpInterfaceTrait::getSuccessorOperands;
// Helper function to deal with Optional operand forms
@ -510,6 +515,10 @@ class fir_SwitchTerminatorOp<string mnemonic, list<OpTrait> traits = []> :
p.printSuccessor(succ);
}
mlir::ArrayAttr getCases() {
return (*this)->getAttrOfType<mlir::ArrayAttr>(getCasesAttr());
}
unsigned targetOffsetSize();
}];
}

View File

@ -174,6 +174,78 @@ struct GlobalOpConversion : public FIROpConversion<fir::GlobalOp> {
}
};
template <typename OP>
void selectMatchAndRewrite(fir::LLVMTypeConverter &lowering, OP select,
typename OP::Adaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) {
unsigned conds = select.getNumConditions();
auto cases = select.getCases().getValue();
mlir::Value selector = adaptor.selector();
auto loc = select.getLoc();
assert(conds > 0 && "select must have cases");
llvm::SmallVector<mlir::Block *> destinations;
llvm::SmallVector<mlir::ValueRange> destinationsOperands;
mlir::Block *defaultDestination;
mlir::ValueRange defaultOperands;
llvm::SmallVector<int32_t> caseValues;
for (unsigned t = 0; t != conds; ++t) {
mlir::Block *dest = select.getSuccessor(t);
auto destOps = select.getSuccessorOperands(adaptor.getOperands(), t);
const mlir::Attribute &attr = cases[t];
if (auto intAttr = attr.template dyn_cast<mlir::IntegerAttr>()) {
destinations.push_back(dest);
destinationsOperands.push_back(destOps.hasValue() ? *destOps
: ValueRange());
caseValues.push_back(intAttr.getInt());
continue;
}
assert(attr.template dyn_cast_or_null<mlir::UnitAttr>());
assert((t + 1 == conds) && "unit must be last");
defaultDestination = dest;
defaultOperands = destOps.hasValue() ? *destOps : ValueRange();
}
// LLVM::SwitchOp takes a i32 type for the selector.
if (select.getSelector().getType() != rewriter.getI32Type())
selector =
rewriter.create<LLVM::TruncOp>(loc, rewriter.getI32Type(), selector);
rewriter.replaceOpWithNewOp<mlir::LLVM::SwitchOp>(
select, selector,
/*defaultDestination=*/defaultDestination,
/*defaultOperands=*/defaultOperands,
/*caseValues=*/caseValues,
/*caseDestinations=*/destinations,
/*caseOperands=*/destinationsOperands,
/*branchWeights=*/ArrayRef<int32_t>());
}
/// conversion of fir::SelectOp to an if-then-else ladder
struct SelectOpConversion : public FIROpConversion<fir::SelectOp> {
using FIROpConversion::FIROpConversion;
mlir::LogicalResult
matchAndRewrite(fir::SelectOp op, OpAdaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const override {
selectMatchAndRewrite<fir::SelectOp>(lowerTy(), op, adaptor, rewriter);
return success();
}
};
/// conversion of fir::SelectRankOp to an if-then-else ladder
struct SelectRankOpConversion : public FIROpConversion<fir::SelectRankOp> {
using FIROpConversion::FIROpConversion;
mlir::LogicalResult
matchAndRewrite(fir::SelectRankOp op, OpAdaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const override {
selectMatchAndRewrite<fir::SelectRankOp>(lowerTy(), op, adaptor, rewriter);
return success();
}
};
// convert to LLVM IR dialect `undef`
struct UndefOpConversion : public FIROpConversion<fir::UndefOp> {
using FIROpConversion::FIROpConversion;
@ -318,8 +390,9 @@ public:
fir::LLVMTypeConverter typeConverter{getModule()};
mlir::OwningRewritePatternList pattern(context);
pattern.insert<AddrOfOpConversion, HasValueOpConversion, GlobalOpConversion,
InsertOnRangeOpConversion, UndefOpConversion,
UnreachableOpConversion, ZeroOpConversion>(typeConverter);
InsertOnRangeOpConversion, SelectOpConversion,
SelectRankOpConversion, UnreachableOpConversion,
ZeroOpConversion, UndefOpConversion>(typeConverter);
mlir::populateStdToLLVMConversionPatterns(typeConverter, pattern);
mlir::arith::populateArithmeticToLLVMConversionPatterns(typeConverter,
pattern);

View File

@ -2264,6 +2264,15 @@ fir::SelectOp::getSuccessorOperands(llvm::ArrayRef<mlir::Value> operands,
return {getSubOperands(oper, getSubOperands(2, operands, segments), a)};
}
llvm::Optional<mlir::ValueRange>
fir::SelectOp::getSuccessorOperands(mlir::ValueRange operands, unsigned oper) {
auto a =
(*this)->getAttrOfType<mlir::DenseIntElementsAttr>(getTargetOffsetAttr());
auto segments = (*this)->getAttrOfType<mlir::DenseIntElementsAttr>(
getOperandSegmentSizeAttr());
return {getSubOperands(oper, getSubOperands(2, operands, segments), a)};
}
unsigned fir::SelectOp::targetOffsetSize() {
return denseElementsSize((*this)->getAttrOfType<mlir::DenseIntElementsAttr>(
getTargetOffsetAttr()));
@ -2557,6 +2566,16 @@ fir::SelectRankOp::getSuccessorOperands(llvm::ArrayRef<mlir::Value> operands,
return {getSubOperands(oper, getSubOperands(2, operands, segments), a)};
}
llvm::Optional<mlir::ValueRange>
fir::SelectRankOp::getSuccessorOperands(mlir::ValueRange operands,
unsigned oper) {
auto a =
(*this)->getAttrOfType<mlir::DenseIntElementsAttr>(getTargetOffsetAttr());
auto segments = (*this)->getAttrOfType<mlir::DenseIntElementsAttr>(
getOperandSegmentSizeAttr());
return {getSubOperands(oper, getSubOperands(2, operands, segments), a)};
}
unsigned fir::SelectRankOp::targetOffsetSize() {
return denseElementsSize((*this)->getAttrOfType<mlir::DenseIntElementsAttr>(
getTargetOffsetAttr()));

View File

@ -167,3 +167,95 @@ func @zero_test_float() {
func @test_unreachable() {
fir.unreachable
}
// -----
// Test `fir.select` operation conversion pattern.
// Check that the if-then-else ladder is correctly constructed and that we
// branch to the correct block.
func @select(%arg : index, %arg2 : i32) -> i32 {
%0 = arith.constant 1 : i32
%1 = arith.constant 2 : i32
%2 = arith.constant 3 : i32
%3 = arith.constant 4 : i32
fir.select %arg:index [ 1, ^bb1(%0:i32),
2, ^bb2(%2,%arg,%arg2:i32,index,i32),
3, ^bb3(%arg2,%2:i32,i32),
4, ^bb4(%1:i32),
unit, ^bb5 ]
^bb1(%a : i32) :
return %a : i32
^bb2(%b : i32, %b2 : index, %b3:i32) :
%castidx = arith.index_cast %b2 : index to i32
%4 = arith.addi %b, %castidx : i32
%5 = arith.addi %4, %b3 : i32
return %5 : i32
^bb3(%c:i32, %c2:i32) :
%6 = arith.addi %c, %c2 : i32
return %6 : i32
^bb4(%d : i32) :
return %d : i32
^bb5 :
%zero = arith.constant 0 : i32
return %zero : i32
}
// CHECK-LABEL: func @select(
// CHECK-SAME: %[[SELECTVALUE:.*]]: [[IDX:.*]],
// CHECK-SAME: %[[ARG1:.*]]: i32)
// CHECK: %[[C0:.*]] = llvm.mlir.constant(1 : i32) : i32
// CHECK: %[[C1:.*]] = llvm.mlir.constant(2 : i32) : i32
// CHECK: %[[C2:.*]] = llvm.mlir.constant(3 : i32) : i32
// CHECK: %[[SELECTOR:.*]] = llvm.trunc %[[SELECTVALUE]] : i{{.*}} to i32
// CHECK: llvm.switch %[[SELECTOR]], ^bb5 [
// CHECK: 1: ^bb1(%[[C0]] : i32),
// CHECK: 2: ^bb2(%[[C2]], %[[SELECTVALUE]], %[[ARG1]] : i32, [[IDX]], i32),
// CHECK: 3: ^bb3(%[[ARG1]], %[[C2]] : i32, i32),
// CHECK: 4: ^bb4(%[[C1]] : i32)
// CHECK: ]
// -----
// Test `fir.select_rank` operation conversion pattern.
// Check that the if-then-else ladder is correctly constructed and that we
// branch to the correct block.
func @select_rank(%arg : i32, %arg2 : i32) -> i32 {
%0 = arith.constant 1 : i32
%1 = arith.constant 2 : i32
%2 = arith.constant 3 : i32
%3 = arith.constant 4 : i32
fir.select_rank %arg:i32 [ 1, ^bb1(%0:i32),
2, ^bb2(%2,%arg,%arg2:i32,i32,i32),
3, ^bb3(%arg2,%2:i32,i32),
4, ^bb4(%1:i32),
unit, ^bb5 ]
^bb1(%a : i32) :
return %a : i32
^bb2(%b : i32, %b2 : i32, %b3:i32) :
%4 = arith.addi %b, %b2 : i32
%5 = arith.addi %4, %b3 : i32
return %5 : i32
^bb3(%c:i32, %c2:i32) :
%6 = arith.addi %c, %c2 : i32
return %6 : i32
^bb4(%d : i32) :
return %d : i32
^bb5 :
%zero = arith.constant 0 : i32
return %zero : i32
}
// CHECK-LABEL: func @select_rank(
// CHECK-SAME: %[[SELECTVALUE:.*]]: i32,
// CHECK-SAME: %[[ARG1:.*]]: i32)
// CHECK: %[[C0:.*]] = llvm.mlir.constant(1 : i32) : i32
// CHECK: %[[C1:.*]] = llvm.mlir.constant(2 : i32) : i32
// CHECK: %[[C2:.*]] = llvm.mlir.constant(3 : i32) : i32
// CHECK: llvm.switch %[[SELECTVALUE]], ^bb5 [
// CHECK: 1: ^bb1(%[[C0]] : i32),
// CHECK: 2: ^bb2(%[[C2]], %[[SELECTVALUE]], %[[ARG1]] : i32, i32, i32),
// CHECK: 3: ^bb3(%[[ARG1]], %[[C2]] : i32, i32),
// CHECK: 4: ^bb4(%[[C1]] : i32)
// CHECK: ]