[mlir] Only conditionally lower CF branching ops to LLVM

Previously cf.br cf.cond_br and cf.switch always lowered to their LLVM
equivalents. These ops are all ops that take in some values of given
types and jump to other blocks with argument lists of the same types. If
the types are not the same, a verification failure will later occur. This led
to confusions, as everything works when func->llvm and cf->llvm lowering
both occur because func->llvm updates the blocks and argument lists
while cf->llvm updates the branching ops. Without func->llvm though,
there will potentially be a type mismatch.

This change now only lowers the CF ops if they will later pass
verification. This is possible because the parent op and its blocks will
be updated before the contained branching ops, so they can test their
new operand types against the types of the blocks they jump to.

Another plan was to have func->llvm only update the entry block
signature and to allow cf->llvm to update all other blocks, but this had
2 problems:
1. This would create a FuncOp lowering in cf->llvm lowering which is
   awkward
2. This new pattern would only be applied if the containing FuncOp is
   marked invalid. This is infeasible with the shared LLVM type
   conversion/target infrastructure.

See previous discussions at
https://discourse.llvm.org/t/lowering-cf-to-llvm/63863 and
https://github.com/llvm/llvm-project/issues/55301

Differential Revision: https://reviews.llvm.org/D130971
This commit is contained in:
Tres Popp 2022-08-02 10:18:52 +02:00
parent d0541b4700
commit 448adfee05
3 changed files with 145 additions and 20 deletions

View File

@ -16,6 +16,14 @@ are expected to closely match the corresponding LLVM IR instructions and
intrinsics. This minimizes the dependency on LLVM IR libraries in MLIR as well
as reduces the churn in case of changes.
Note that many different dialects can be lowered to LLVM but are provided as
different sets of patterns and have different passes available to mlir-opt.
However, this is primarily useful for testing and prototyping, and using the
collection of patterns together is highly recommended. One place this is
important and visible is the ControlFlow dialect's branching operations which
will fail to apply if their types mismatch with the blocks they jump to in the
parent op.
SPIR-V to LLVM dialect conversion has a
[dedicated document](SPIRVToLLVMDialectConversion.md).

View File

@ -22,6 +22,7 @@
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/StringRef.h"
#include <functional>
using namespace mlir;
@ -71,34 +72,108 @@ struct AssertOpLowering : public ConvertOpToLLVMPattern<cf::AssertOp> {
}
};
// Base class for LLVM IR lowering terminator operations with successors.
template <typename SourceOp, typename TargetOp>
struct OneToOneLLVMTerminatorLowering
: public ConvertOpToLLVMPattern<SourceOp> {
using ConvertOpToLLVMPattern<SourceOp>::ConvertOpToLLVMPattern;
using Base = OneToOneLLVMTerminatorLowering<SourceOp, TargetOp>;
/// The cf->LLVM lowerings for branching ops require that the blocks they jump
/// to first have updated types which should be handled by a pattern operating
/// on the parent op.
static LogicalResult verifyMatchingValues(ConversionPatternRewriter &rewriter,
ValueRange operands,
ValueRange blockArgs, Location loc,
llvm::StringRef messagePrefix) {
for (const auto &idxAndTypes :
llvm::enumerate(llvm::zip(blockArgs, operands))) {
int64_t i = idxAndTypes.index();
Value argValue =
rewriter.getRemappedValue(std::get<0>(idxAndTypes.value()));
Type operandType = std::get<1>(idxAndTypes.value()).getType();
// In the case of an invalid jump, the block argument will have been
// remapped to an UnrealizedConversionCast. In the case of a valid jump,
// there might still be a no-op conversion cast with both types being equal.
// Consider both of these details to see if the jump would be invalid.
if (auto op = dyn_cast_or_null<UnrealizedConversionCastOp>(
argValue.getDefiningOp())) {
if (op.getOperandTypes().front() != operandType) {
return rewriter.notifyMatchFailure(loc, [&](Diagnostic &diag) {
diag << messagePrefix;
diag << "mismatched types from operand # " << i << " ";
diag << operandType;
diag << " not compatible with destination block argument type ";
diag << argValue.getType();
diag << " which should be converted with the parent op.";
});
}
}
}
return success();
}
/// Ensure that all block types were updated and then create an LLVM::BrOp
struct BranchOpLowering : public ConvertOpToLLVMPattern<cf::BranchOp> {
using ConvertOpToLLVMPattern<cf::BranchOp>::ConvertOpToLLVMPattern;
LogicalResult
matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor,
matchAndRewrite(cf::BranchOp op, typename cf::BranchOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
rewriter.replaceOpWithNewOp<TargetOp>(op, adaptor.getOperands(),
op->getSuccessors(), op->getAttrs());
if (failed(verifyMatchingValues(rewriter, adaptor.getDestOperands(),
op.getSuccessor()->getArguments(),
op.getLoc(),
/*messagePrefix=*/"")))
return failure();
rewriter.replaceOpWithNewOp<LLVM::BrOp>(
op, adaptor.getOperands(), op->getSuccessors(), op->getAttrs());
return success();
}
};
// FIXME: this should be tablegen'ed as well.
struct BranchOpLowering
: public OneToOneLLVMTerminatorLowering<cf::BranchOp, LLVM::BrOp> {
using Base::Base;
/// Ensure that all block types were updated and then create an LLVM::CondBrOp
struct CondBranchOpLowering : public ConvertOpToLLVMPattern<cf::CondBranchOp> {
using ConvertOpToLLVMPattern<cf::CondBranchOp>::ConvertOpToLLVMPattern;
LogicalResult
matchAndRewrite(cf::CondBranchOp op,
typename cf::CondBranchOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
if (failed(verifyMatchingValues(rewriter, adaptor.getFalseDestOperands(),
op.getFalseDest()->getArguments(),
op.getLoc(), "in false case branch ")))
return failure();
if (failed(verifyMatchingValues(rewriter, adaptor.getTrueDestOperands(),
op.getTrueDest()->getArguments(),
op.getLoc(), "in true case branch ")))
return failure();
rewriter.replaceOpWithNewOp<LLVM::CondBrOp>(
op, adaptor.getOperands(), op->getSuccessors(), op->getAttrs());
return success();
}
};
struct CondBranchOpLowering
: public OneToOneLLVMTerminatorLowering<cf::CondBranchOp, LLVM::CondBrOp> {
using Base::Base;
};
struct SwitchOpLowering
: public OneToOneLLVMTerminatorLowering<cf::SwitchOp, LLVM::SwitchOp> {
using Base::Base;
/// Ensure that all block types were updated and then create an LLVM::SwitchOp
struct SwitchOpLowering : public ConvertOpToLLVMPattern<cf::SwitchOp> {
using ConvertOpToLLVMPattern<cf::SwitchOp>::ConvertOpToLLVMPattern;
LogicalResult
matchAndRewrite(cf::SwitchOp op, typename cf::SwitchOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
if (failed(verifyMatchingValues(rewriter, adaptor.getDefaultOperands(),
op.getDefaultDestination()->getArguments(),
op.getLoc(), "in switch default case ")))
return failure();
for (const auto &i : llvm::enumerate(
llvm::zip(adaptor.getCaseOperands(), op.getCaseDestinations()))) {
if (failed(verifyMatchingValues(
rewriter, std::get<0>(i.value()),
std::get<1>(i.value())->getArguments(), op.getLoc(),
"in switch case " + std::to_string(i.index()) + " "))) {
return failure();
}
}
rewriter.replaceOpWithNewOp<LLVM::SwitchOp>(
op, adaptor.getOperands(), op->getSuccessors(), op->getAttrs());
return success();
}
};
} // namespace

View File

@ -0,0 +1,42 @@
// RUN: mlir-opt %s -convert-cf-to-llvm | FileCheck %s
func.func @name(%flag: i32, %pred: i1){
// Test cf.br lowering failure with type mismatch
// CHECK: cf.br
%c0 = arith.constant 0 : index
cf.br ^bb1(%c0 : index)
// Test cf.cond_br lowering failure with type mismatch in false_dest
// CHECK: cf.cond_br
^bb1(%0: index): // 2 preds: ^bb0, ^bb2
%c1 = arith.constant 1 : i1
%c2 = arith.constant 1 : index
cf.cond_br %pred, ^bb2(%c1: i1), ^bb3(%c2: index)
// Test cf.cond_br lowering failure with type mismatch in true_dest
// CHECK: cf.cond_br
^bb2(%1: i1):
%c3 = arith.constant 1 : i1
%c4 = arith.constant 1 : index
cf.cond_br %pred, ^bb3(%c4: index), ^bb2(%c3: i1)
// Test cf.switch lowering failure with type mismatch in default case
// CHECK: cf.switch
^bb3(%2: index): // pred: ^bb1
%c5 = arith.constant 1 : i1
%c6 = arith.constant 1 : index
cf.switch %flag : i32, [
default: ^bb1(%c6 : index),
42: ^bb4(%c5 : i1)
]
// Test cf.switch lowering failure with type mismatch in non-default case
// CHECK: cf.switch
^bb4(%3: i1): // pred: ^bb1
%c7 = arith.constant 1 : i1
%c8 = arith.constant 1 : index
cf.switch %flag : i32, [
default: ^bb2(%c7 : i1),
41: ^bb1(%c8 : index)
]
}