[MLIR] Add a switch operation to the standard dialect

This is similar to the definition of llvm.switch, providing
unstructured branch-based control flow. It differs from the LLVM
operation in that it accepts any signless integer (not only an i32),
takes no branch weights (the same as the Branch and CondBranch ops),
and has a slightly different syntax for the default case that includes
it in the list of cases with an explicit `default` keyword.

Also included are several canonicalizers.

See https://llvm.discourse.group/t/rfc-add-std-switch-and-scf-switch/3090

Reviewed By: rriddle, bondhugula

Differential Revision: https://reviews.llvm.org/D99925
This commit is contained in:
Geoffrey Martin-Noble 2021-04-12 17:01:30 -07:00
parent d926498d9f
commit ae33eef505
6 changed files with 939 additions and 6 deletions

View File

@ -2030,6 +2030,89 @@ def SubTensorInsertOp : BaseOpWithOffsetSizesAndStrides<
let hasFolder = 1;
}
//===----------------------------------------------------------------------===//
// SwitchOp
//===----------------------------------------------------------------------===//
def SwitchOp : Std_Op<"switch",
[AttrSizedOperandSegments,
DeclareOpInterfaceMethods<BranchOpInterface, ["getSuccessorForOperands"]>,
NoSideEffect, Terminator]> {
let summary = "switch operation";
let description = [{
The `switch` terminator operation represents a switch on a signless integer
value. If the flag matches one of the specified cases, then the
corresponding destination is jumped to. If the flag does not match any of
the cases, the default destination is jumped to. The count and types of
operands must align with the arguments in the corresponding target blocks.
Example:
```mlir
switch %flag : i32, [
default: ^bb1(%a : i32),
42: ^bb1(%b : i32),
43: ^bb3(%c : i32)
]
```
}];
let arguments = (ins AnyInteger:$flag,
Variadic<AnyType>:$defaultOperands,
Variadic<AnyType>:$caseOperands,
OptionalAttr<AnyIntElementsAttr>:$case_values,
OptionalAttr<I32ElementsAttr>:$case_operand_offsets);
let successors = (successor
AnySuccessor:$defaultDestination,
VariadicSuccessor<AnySuccessor>:$caseDestinations);
let builders = [
OpBuilder<(ins "Value":$flag,
"Block *":$defaultDestination,
"ValueRange":$defaultOperands,
CArg<"ArrayRef<APInt>", "{}">:$caseValues,
CArg<"BlockRange", "{}">:$caseDestinations,
CArg<"ArrayRef<ValueRange>", "{}">:$caseOperands)>,
OpBuilder<(ins "Value":$flag,
"Block *":$defaultDestination,
"ValueRange":$defaultOperands,
CArg<"ArrayRef<int32_t>", "{}">:$caseValues,
CArg<"BlockRange", "{}">:$caseDestinations,
CArg<"ArrayRef<ValueRange>", "{}">:$caseOperands)>,
OpBuilder<(ins "Value":$flag,
"Block *":$defaultDestination,
"ValueRange":$defaultOperands,
CArg<"DenseIntElementsAttr", "{}">:$caseValues,
CArg<"BlockRange", "{}">:$caseDestinations,
CArg<"ArrayRef<ValueRange>", "{}">:$caseOperands)>
];
let assemblyFormat = [{
$flag `:` type($flag) `,` `[` `\n`
custom<SwitchOpCases>(ref(type($flag)),$defaultDestination,
$defaultOperands,
type($defaultOperands),
$case_values,
$caseDestinations,
$caseOperands,
type($caseOperands),
$case_operand_offsets)
`]`
attr-dict
}];
let extraClassDeclaration = [{
/// Return the operands for the case destination block at the given index.
OperandRange getCaseOperands(unsigned index);
/// Return a mutable range of operands for the case destination block at the
/// given index.
MutableOperandRange getCaseOperandsMutable(unsigned index);
}];
let hasCanonicalizer = 1;
}
//===----------------------------------------------------------------------===//
// TruncateIOp
//===----------------------------------------------------------------------===//

View File

@ -1333,13 +1333,15 @@ def IndexElementsAttr
.isIndex()}]>,
"index elements attribute">;
class AnyIntElementsAttr<int width> : IntElementsAttrBase<
def AnyIntElementsAttr : IntElementsAttrBase<CPred<"true">, "integer elements attribute">;
class IntElementsAttrOf<int width> : IntElementsAttrBase<
CPred<"$_self.cast<::mlir::DenseIntElementsAttr>().getType()."
"getElementType().isInteger(" # width # ")">,
width # "-bit integer elements attribute">;
def AnyI32ElementsAttr : AnyIntElementsAttr<32>;
def AnyI64ElementsAttr : AnyIntElementsAttr<64>;
def AnyI32ElementsAttr : IntElementsAttrOf<32>;
def AnyI64ElementsAttr : IntElementsAttrOf<64>;
class SignlessIntElementsAttr<int width> : IntElementsAttrBase<
CPred<"$_self.cast<::mlir::DenseIntElementsAttr>().getType()."

View File

@ -441,8 +441,9 @@ static LogicalResult verify(AtomicYieldOp op) {
/// Given a successor, try to collapse it to a new destination if it only
/// contains a passthrough unconditional branch. If the successor is
/// collapsable, `successor` and `successorOperands` are updated to reference
/// the new destination and values. `argStorage` is an optional storage to use
/// if operands to the collapsed successor need to be remapped.
/// the new destination and values. `argStorage` is used as storage if operands
/// to the collapsed successor need to be remapped. It must outlive uses of
/// successorOperands.
static LogicalResult collapseBranch(Block *&successor,
ValueRange &successorOperands,
SmallVectorImpl<Value> &argStorage) {
@ -2160,6 +2161,490 @@ void SubTensorInsertOp::getCanonicalizationPatterns(RewritePatternSet &results,
SubTensorInsertOpCastFolder>(context);
}
//===----------------------------------------------------------------------===//
// SwitchOp
//===----------------------------------------------------------------------===//
void SwitchOp::build(OpBuilder &builder, OperationState &result, Value value,
Block *defaultDestination, ValueRange defaultOperands,
DenseIntElementsAttr caseValues,
BlockRange caseDestinations,
ArrayRef<ValueRange> caseOperands) {
SmallVector<Value> flattenedCaseOperands;
SmallVector<int32_t> caseOperandOffsets;
int32_t offset = 0;
for (ValueRange operands : caseOperands) {
flattenedCaseOperands.append(operands.begin(), operands.end());
caseOperandOffsets.push_back(offset);
offset += operands.size();
}
DenseIntElementsAttr caseOperandOffsetsAttr;
if (!caseOperandOffsets.empty())
caseOperandOffsetsAttr = builder.getI32VectorAttr(caseOperandOffsets);
build(builder, result, value, defaultOperands, flattenedCaseOperands,
caseValues, caseOperandOffsetsAttr, defaultDestination,
caseDestinations);
}
void SwitchOp::build(OpBuilder &builder, OperationState &result, Value value,
Block *defaultDestination, ValueRange defaultOperands,
ArrayRef<APInt> caseValues, BlockRange caseDestinations,
ArrayRef<ValueRange> caseOperands) {
DenseIntElementsAttr caseValuesAttr;
if (!caseValues.empty()) {
ShapedType caseValueType = VectorType::get(
static_cast<int64_t>(caseValues.size()), value.getType());
caseValuesAttr = DenseIntElementsAttr::get(caseValueType, caseValues);
}
build(builder, result, value, defaultDestination, defaultOperands,
caseValuesAttr, caseDestinations, caseOperands);
}
/// <cases> ::= `default` `:` bb-id (`(` ssa-use-and-type-list `)`)?
/// ( `,` integer `:` bb-id (`(` ssa-use-and-type-list `)`)? )*
static ParseResult
parseSwitchOpCases(OpAsmParser &parser, Type &flagType,
Block *&defaultDestination,
SmallVectorImpl<OpAsmParser::OperandType> &defaultOperands,
SmallVectorImpl<Type> &defaultOperandTypes,
DenseIntElementsAttr &caseValues,
SmallVectorImpl<Block *> &caseDestinations,
SmallVectorImpl<OpAsmParser::OperandType> &caseOperands,
SmallVectorImpl<Type> &caseOperandTypes,
DenseIntElementsAttr &caseOperandOffsets) {
if (failed(parser.parseKeyword("default")) || failed(parser.parseColon()) ||
failed(parser.parseSuccessor(defaultDestination)))
return failure();
if (succeeded(parser.parseOptionalLParen())) {
if (failed(parser.parseRegionArgumentList(defaultOperands)) ||
failed(parser.parseColonTypeList(defaultOperandTypes)) ||
failed(parser.parseRParen()))
return failure();
}
SmallVector<APInt> values;
SmallVector<int32_t> offsets;
unsigned bitWidth = flagType.getIntOrFloatBitWidth();
int64_t offset = 0;
while (succeeded(parser.parseOptionalComma())) {
int64_t value = 0;
if (failed(parser.parseInteger(value)))
return failure();
values.push_back(APInt(bitWidth, value));
Block *destination;
SmallVector<OpAsmParser::OperandType> operands;
if (failed(parser.parseColon()) ||
failed(parser.parseSuccessor(destination)))
return failure();
if (succeeded(parser.parseOptionalLParen())) {
if (failed(parser.parseRegionArgumentList(operands)) ||
failed(parser.parseColonTypeList(caseOperandTypes)) ||
failed(parser.parseRParen()))
return failure();
}
caseDestinations.push_back(destination);
caseOperands.append(operands.begin(), operands.end());
offsets.push_back(offset);
offset += operands.size();
}
if (values.empty())
return success();
Builder &builder = parser.getBuilder();
ShapedType caseValueType =
VectorType::get(static_cast<int64_t>(values.size()), flagType);
caseValues = DenseIntElementsAttr::get(caseValueType, values);
caseOperandOffsets = builder.getI32VectorAttr(offsets);
return success();
}
static void printSwitchOpCases(
OpAsmPrinter &p, SwitchOp op, Type flagType, Block *defaultDestination,
OperandRange defaultOperands, TypeRange defaultOperandTypes,
DenseIntElementsAttr caseValues, SuccessorRange caseDestinations,
OperandRange caseOperands, TypeRange caseOperandTypes,
ElementsAttr caseOperandOffsets) {
p << " default: ";
p.printSuccessorAndUseList(defaultDestination, defaultOperands);
if (!caseValues)
return;
for (int64_t i = 0, size = caseValues.size(); i < size; ++i) {
p << ',';
p.printNewline();
p << " ";
p << caseValues.getValue<APInt>(i).getLimitedValue();
p << ": ";
p.printSuccessorAndUseList(caseDestinations[i], op.getCaseOperands(i));
}
p.printNewline();
}
static LogicalResult verify(SwitchOp op) {
auto caseValues = op.case_values();
auto caseDestinations = op.caseDestinations();
if (!caseValues && caseDestinations.empty())
return success();
Type flagType = op.flag().getType();
Type caseValueType = caseValues->getType().getElementType();
if (caseValueType != flagType)
return op.emitOpError()
<< "'flag' type (" << flagType << ") should match case value type ("
<< caseValueType << ")";
if (caseValues &&
caseValues->size() != static_cast<int64_t>(caseDestinations.size()))
return op.emitOpError() << "number of case values (" << caseValues->size()
<< ") should match number of "
"case destinations ("
<< caseDestinations.size() << ")";
return success();
}
OperandRange SwitchOp::getCaseOperands(unsigned index) {
return getCaseOperandsMutable(index);
}
MutableOperandRange SwitchOp::getCaseOperandsMutable(unsigned index) {
MutableOperandRange caseOperands = caseOperandsMutable();
if (!case_operand_offsets()) {
assert(caseOperands.size() == 0 &&
"non-empty case operands must have offsets");
return caseOperands;
}
ElementsAttr offsets = case_operand_offsets().getValue();
assert(index < offsets.size() && "invalid case operand offset index");
int64_t begin = offsets.getValue(index).cast<IntegerAttr>().getInt();
int64_t end = index + 1 == offsets.size()
? caseOperands.size()
: offsets.getValue(index + 1).cast<IntegerAttr>().getInt();
return caseOperandsMutable().slice(begin, end - begin);
}
Optional<MutableOperandRange>
SwitchOp::getMutableSuccessorOperands(unsigned index) {
assert(index < getNumSuccessors() && "invalid successor index");
return index == 0 ? defaultOperandsMutable()
: getCaseOperandsMutable(index - 1);
}
Block *SwitchOp::getSuccessorForOperands(ArrayRef<Attribute> operands) {
Optional<DenseIntElementsAttr> caseValues = case_values();
if (!caseValues)
return defaultDestination();
SuccessorRange caseDests = caseDestinations();
if (auto value = operands.front().dyn_cast_or_null<IntegerAttr>()) {
for (int64_t i = 0, size = case_values()->size(); i < size; ++i)
if (value == caseValues->getValue<IntegerAttr>(i))
return caseDests[i];
return defaultDestination();
}
return nullptr;
}
/// switch %flag : i32, [
/// default: ^bb1
/// ]
/// -> br ^bb1
static LogicalResult simplifySwitchWithOnlyDefault(SwitchOp op,
PatternRewriter &rewriter) {
if (!op.caseDestinations().empty())
return failure();
rewriter.replaceOpWithNewOp<BranchOp>(op, op.defaultDestination(),
op.defaultOperands());
return success();
}
/// switch %flag : i32, [
/// default: ^bb1,
/// 42: ^bb1,
/// 43: ^bb2
/// ]
/// ->
/// switch %flag : i32, [
/// default: ^bb1,
/// 43: ^bb2
/// ]
static LogicalResult
dropSwitchCasesThatMatchDefault(SwitchOp op, PatternRewriter &rewriter) {
SmallVector<Block *> newCaseDestinations;
SmallVector<ValueRange> newCaseOperands;
SmallVector<APInt> newCaseValues;
bool requiresChange = false;
auto caseValues = op.case_values();
auto caseDests = op.caseDestinations();
for (int64_t i = 0, size = caseValues->size(); i < size; ++i) {
if (caseDests[i] == op.defaultDestination() &&
op.getCaseOperands(i) == op.defaultOperands()) {
requiresChange = true;
continue;
}
newCaseDestinations.push_back(caseDests[i]);
newCaseOperands.push_back(op.getCaseOperands(i));
newCaseValues.push_back(caseValues->getValue<APInt>(i));
}
if (!requiresChange)
return failure();
rewriter.replaceOpWithNewOp<SwitchOp>(op, op.flag(), op.defaultDestination(),
op.defaultOperands(), newCaseValues,
newCaseDestinations, newCaseOperands);
return success();
}
/// Helper for folding a switch with a constant value.
/// switch %c_42 : i32, [
/// default: ^bb1 ,
/// 42: ^bb2,
/// 43: ^bb3
/// ]
/// -> br ^bb2
static void foldSwitch(SwitchOp op, PatternRewriter &rewriter,
APInt caseValue) {
auto caseValues = op.case_values();
for (int64_t i = 0, size = caseValues->size(); i < size; ++i) {
if (caseValues->getValue<APInt>(i) == caseValue) {
rewriter.replaceOpWithNewOp<BranchOp>(op, op.caseDestinations()[i],
op.getCaseOperands(i));
return;
}
}
rewriter.replaceOpWithNewOp<BranchOp>(op, op.defaultDestination(),
op.defaultOperands());
}
/// switch %c_42 : i32, [
/// default: ^bb1,
/// 42: ^bb2,
/// 43: ^bb3
/// ]
/// -> br ^bb2
static LogicalResult simplifyConstSwitchValue(SwitchOp op,
PatternRewriter &rewriter) {
APInt caseValue;
if (!matchPattern(op.flag(), m_ConstantInt(&caseValue)))
return failure();
foldSwitch(op, rewriter, caseValue);
return success();
}
/// switch %c_42 : i32, [
/// default: ^bb1,
/// 42: ^bb2,
/// ]
/// ^bb2:
/// br ^bb3
/// ->
/// switch %c_42 : i32, [
/// default: ^bb1,
/// 42: ^bb3,
/// ]
static LogicalResult simplifyPassThroughSwitch(SwitchOp op,
PatternRewriter &rewriter) {
SmallVector<Block *> newCaseDests;
SmallVector<ValueRange> newCaseOperands;
SmallVector<SmallVector<Value>> argStorage;
auto caseValues = op.case_values();
auto caseDests = op.caseDestinations();
bool requiresChange = false;
for (int64_t i = 0, size = caseValues->size(); i < size; ++i) {
Block *caseDest = caseDests[i];
ValueRange caseOperands = op.getCaseOperands(i);
argStorage.emplace_back();
if (succeeded(collapseBranch(caseDest, caseOperands, argStorage.back())))
requiresChange = true;
newCaseDests.push_back(caseDest);
newCaseOperands.push_back(caseOperands);
}
Block *defaultDest = op.defaultDestination();
ValueRange defaultOperands = op.defaultOperands();
argStorage.emplace_back();
if (succeeded(
collapseBranch(defaultDest, defaultOperands, argStorage.back())))
requiresChange = true;
if (!requiresChange)
return failure();
rewriter.replaceOpWithNewOp<SwitchOp>(op, op.flag(), defaultDest,
defaultOperands, caseValues.getValue(),
newCaseDests, newCaseOperands);
return success();
}
/// switch %flag : i32, [
/// default: ^bb1,
/// 42: ^bb2,
/// ]
/// ^bb2:
/// switch %flag : i32, [
/// default: ^bb3,
/// 42: ^bb4
/// ]
/// ->
/// switch %flag : i32, [
/// default: ^bb1,
/// 42: ^bb2,
/// ]
/// ^bb2:
/// br ^bb4
///
/// and
///
/// switch %flag : i32, [
/// default: ^bb1,
/// 42: ^bb2,
/// ]
/// ^bb2:
/// switch %flag : i32, [
/// default: ^bb3,
/// 43: ^bb4
/// ]
/// ->
/// switch %flag : i32, [
/// default: ^bb1,
/// 42: ^bb2,
/// ]
/// ^bb2:
/// br ^bb3
static LogicalResult
simplifySwitchFromSwitchOnSameCondition(SwitchOp op,
PatternRewriter &rewriter) {
// Check that we have a single distinct predecessor.
Block *currentBlock = op->getBlock();
Block *predecessor = currentBlock->getSinglePredecessor();
if (!predecessor)
return failure();
// Check that the predecessor terminates with a switch branch to this block
// and that it branches on the same condition and that this branch isn't the
// default destination.
auto predSwitch = dyn_cast<SwitchOp>(predecessor->getTerminator());
if (!predSwitch || op.flag() != predSwitch.flag() ||
predSwitch.defaultDestination() == currentBlock)
return failure();
// Fold this switch to an unconditional branch.
APInt caseValue;
bool isDefault = true;
SuccessorRange predDests = predSwitch.caseDestinations();
Optional<DenseIntElementsAttr> predCaseValues = predSwitch.case_values();
for (int64_t i = 0, size = predCaseValues->size(); i < size; ++i) {
if (currentBlock == predDests[i]) {
caseValue = predCaseValues->getValue<APInt>(i);
isDefault = false;
break;
}
}
if (isDefault)
rewriter.replaceOpWithNewOp<BranchOp>(op, op.defaultDestination(),
op.defaultOperands());
else
foldSwitch(op, rewriter, caseValue);
return success();
}
/// switch %flag : i32, [
/// default: ^bb1,
/// 42: ^bb2
/// ]
/// ^bb1:
/// switch %flag : i32, [
/// default: ^bb3,
/// 42: ^bb4,
/// 43: ^bb5
/// ]
/// ->
/// switch %flag : i32, [
/// default: ^bb1,
/// 42: ^bb2,
/// ]
/// ^bb1:
/// switch %flag : i32, [
/// default: ^bb3,
/// 43: ^bb5
/// ]
static LogicalResult
simplifySwitchFromDefaultSwitchOnSameCondition(SwitchOp op,
PatternRewriter &rewriter) {
// Check that we have a single distinct predecessor.
Block *currentBlock = op->getBlock();
Block *predecessor = currentBlock->getSinglePredecessor();
if (!predecessor)
return failure();
// Check that the predecessor terminates with a switch branch to this block
// and that it branches on the same condition and that this branch is the
// default destination.
auto predSwitch = dyn_cast<SwitchOp>(predecessor->getTerminator());
if (!predSwitch || op.flag() != predSwitch.flag() ||
predSwitch.defaultDestination() != currentBlock)
return failure();
// Delete case values that are not possible here.
DenseSet<APInt> caseValuesToRemove;
auto predDests = predSwitch.caseDestinations();
auto predCaseValues = predSwitch.case_values();
for (int64_t i = 0, size = predCaseValues->size(); i < size; ++i)
if (currentBlock != predDests[i])
caseValuesToRemove.insert(predCaseValues->getValue<APInt>(i));
SmallVector<Block *> newCaseDestinations;
SmallVector<ValueRange> newCaseOperands;
SmallVector<APInt> newCaseValues;
bool requiresChange = false;
auto caseValues = op.case_values();
auto caseDests = op.caseDestinations();
for (int64_t i = 0, size = caseValues->size(); i < size; ++i) {
if (caseValuesToRemove.contains(caseValues->getValue<APInt>(i))) {
requiresChange = true;
continue;
}
newCaseDestinations.push_back(caseDests[i]);
newCaseOperands.push_back(op.getCaseOperands(i));
newCaseValues.push_back(caseValues->getValue<APInt>(i));
}
if (!requiresChange)
return failure();
rewriter.replaceOpWithNewOp<SwitchOp>(op, op.flag(), op.defaultDestination(),
op.defaultOperands(), newCaseValues,
newCaseDestinations, newCaseOperands);
return success();
}
void SwitchOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.add(&simplifySwitchWithOnlyDefault)
.add(&dropSwitchCasesThatMatchDefault)
.add(&simplifyConstSwitchValue)
.add(&simplifyPassThroughSwitch)
.add(&simplifySwitchFromSwitchOnSameCondition)
.add(&simplifySwitchFromDefaultSwitchOnSameCondition);
}
//===----------------------------------------------------------------------===//
// TruncateIOp
//===----------------------------------------------------------------------===//

View File

@ -1,4 +1,4 @@
// RUN: mlir-opt %s -allow-unregistered-dialect -pass-pipeline='func(canonicalize)' -split-input-file | FileCheck %s
// RUN: mlir-opt %s -allow-unregistered-dialect -pass-pipeline='func(canonicalize)' -split-input-file | FileCheck --dump-input-context 20 %s
/// Test the folding of BranchOp.
@ -139,6 +139,268 @@ func @cond_br_pass_through_fail(%cond : i1) {
return
}
/// Test the folding of SwitchOp
// CHECK-LABEL: func @switch_only_default(
// CHECK-SAME: %[[FLAG:[a-zA-Z0-9_]+]]
// CHECK-SAME: %[[CASE_OPERAND_0:[a-zA-Z0-9_]+]]
func @switch_only_default(%flag : i32, %caseOperand0 : f32) {
// add predecessors for all blocks to avoid other canonicalizations.
"foo.pred"() [^bb1, ^bb2] : () -> ()
^bb1:
// CHECK-NOT: switch
// CHECK: br ^[[BB2:[a-zA-Z0-9_]+]](%[[CASE_OPERAND_0]]
switch %flag : i32, [
default: ^bb2(%caseOperand0 : f32)
]
// CHECK: ^[[BB2]]({{.*}}):
^bb2(%bb2Arg : f32):
// CHECK-NEXT: "foo.bb2Terminator"
"foo.bb2Terminator"(%bb2Arg) : (f32) -> ()
}
// CHECK-LABEL: func @switch_case_matching_default(
// CHECK-SAME: %[[FLAG:[a-zA-Z0-9_]+]]
// CHECK-SAME: %[[CASE_OPERAND_0:[a-zA-Z0-9_]+]]
// CHECK-SAME: %[[CASE_OPERAND_1:[a-zA-Z0-9_]+]]
func @switch_case_matching_default(%flag : i32, %caseOperand0 : f32, %caseOperand1 : f32) {
// add predecessors for all blocks to avoid other canonicalizations.
"foo.pred"() [^bb1, ^bb2, ^bb3] : () -> ()
^bb1:
// CHECK: switch %[[FLAG]]
// CHECK-NEXT: default: ^[[BB1:.+]](%[[CASE_OPERAND_0]] : f32)
// CHECK-NEXT: 10: ^[[BB2:.+]](%[[CASE_OPERAND_1]] : f32)
// CHECK-NEXT: ]
switch %flag : i32, [
default: ^bb2(%caseOperand0 : f32),
42: ^bb2(%caseOperand0 : f32),
10: ^bb3(%caseOperand1 : f32),
17: ^bb2(%caseOperand0 : f32)
]
^bb2(%bb2Arg : f32):
"foo.bb2Terminator"(%bb2Arg) : (f32) -> ()
^bb3(%bb3Arg : f32):
"foo.bb3Terminator"(%bb3Arg) : (f32) -> ()
}
// CHECK-LABEL: func @switch_on_const_no_match(
// CHECK-SAME: %[[CASE_OPERAND_0:[a-zA-Z0-9_]+]]
// CHECK-SAME: %[[CASE_OPERAND_1:[a-zA-Z0-9_]+]]
// CHECK-SAME: %[[CASE_OPERAND_2:[a-zA-Z0-9_]+]]
func @switch_on_const_no_match(%caseOperand0 : f32, %caseOperand1 : f32, %caseOperand2 : f32) {
// add predecessors for all blocks to avoid other canonicalizations.
"foo.pred"() [^bb1, ^bb2, ^bb3, ^bb4] : () -> ()
^bb1:
// CHECK-NOT: switch
// CHECK: br ^[[BB2:[a-zA-Z0-9_]+]](%[[CASE_OPERAND_0]]
%c0_i32 = constant 0 : i32
switch %c0_i32 : i32, [
default: ^bb2(%caseOperand0 : f32),
-1: ^bb3(%caseOperand1 : f32),
1: ^bb4(%caseOperand2 : f32)
]
// CHECK: ^[[BB2]]({{.*}}):
// CHECK-NEXT: "foo.bb2Terminator"
^bb2(%bb2Arg : f32):
"foo.bb2Terminator"(%bb2Arg) : (f32) -> ()
^bb3(%bb3Arg : f32):
"foo.bb3Terminator"(%bb3Arg) : (f32) -> ()
^bb4(%bb4Arg : f32):
"foo.bb4Terminator"(%bb4Arg) : (f32) -> ()
}
// CHECK-LABEL: func @switch_on_const_with_match(
// CHECK-SAME: %[[CASE_OPERAND_0:[a-zA-Z0-9_]+]]
// CHECK-SAME: %[[CASE_OPERAND_1:[a-zA-Z0-9_]+]]
// CHECK-SAME: %[[CASE_OPERAND_2:[a-zA-Z0-9_]+]]
func @switch_on_const_with_match(%caseOperand0 : f32, %caseOperand1 : f32, %caseOperand2 : f32) {
// add predecessors for all blocks to avoid other canonicalizations.
"foo.pred"() [^bb1, ^bb2, ^bb3, ^bb4] : () -> ()
^bb1:
// CHECK-NOT: switch
// CHECK: br ^[[BB4:[a-zA-Z0-9_]+]](%[[CASE_OPERAND_2]]
%c0_i32 = constant 1 : i32
switch %c0_i32 : i32, [
default: ^bb2(%caseOperand0 : f32),
-1: ^bb3(%caseOperand1 : f32),
1: ^bb4(%caseOperand2 : f32)
]
^bb2(%bb2Arg : f32):
"foo.bb2Terminator"(%bb2Arg) : (f32) -> ()
^bb3(%bb3Arg : f32):
"foo.bb3Terminator"(%bb3Arg) : (f32) -> ()
// CHECK: ^[[BB4]]({{.*}}):
// CHECK-NEXT: "foo.bb4Terminator"
^bb4(%bb4Arg : f32):
"foo.bb4Terminator"(%bb4Arg) : (f32) -> ()
}
// CHECK-LABEL: func @switch_passthrough(
// CHECK-SAME: %[[FLAG:[a-zA-Z0-9_]+]]
// CHECK-SAME: %[[CASE_OPERAND_0:[a-zA-Z0-9_]+]]
// CHECK-SAME: %[[CASE_OPERAND_1:[a-zA-Z0-9_]+]]
// CHECK-SAME: %[[CASE_OPERAND_2:[a-zA-Z0-9_]+]]
// CHECK-SAME: %[[CASE_OPERAND_3:[a-zA-Z0-9_]+]]
func @switch_passthrough(%flag : i32,
%caseOperand0 : f32,
%caseOperand1 : f32,
%caseOperand2 : f32,
%caseOperand3 : f32) {
// add predecessors for all blocks to avoid other canonicalizations.
"foo.pred"() [^bb1, ^bb2, ^bb3, ^bb4, ^bb5, ^bb6] : () -> ()
^bb1:
// CHECK: switch %[[FLAG]]
// CHECK-NEXT: default: ^[[BB5:[a-zA-Z0-9_]+]](%[[CASE_OPERAND_0]]
// CHECK-NEXT: 43: ^[[BB6:[a-zA-Z0-9_]+]](%[[CASE_OPERAND_1]]
// CHECK-NEXT: 44: ^[[BB4:[a-zA-Z0-9_]+]](%[[CASE_OPERAND_2]]
// CHECK-NEXT: ]
switch %flag : i32, [
default: ^bb2(%caseOperand0 : f32),
43: ^bb3(%caseOperand1 : f32),
44: ^bb4(%caseOperand2 : f32)
]
^bb2(%bb2Arg : f32):
br ^bb5(%bb2Arg : f32)
^bb3(%bb3Arg : f32):
br ^bb6(%bb3Arg : f32)
^bb4(%bb4Arg : f32):
"foo.bb4Terminator"(%bb4Arg) : (f32) -> ()
// CHECK: ^[[BB5]]({{.*}}):
// CHECK-NEXT: "foo.bb5Terminator"
^bb5(%bb5Arg : f32):
"foo.bb5Terminator"(%bb5Arg) : (f32) -> ()
// CHECK: ^[[BB6]]({{.*}}):
// CHECK-NEXT: "foo.bb6Terminator"
^bb6(%bb6Arg : f32):
"foo.bb6Terminator"(%bb6Arg) : (f32) -> ()
}
// CHECK-LABEL: func @switch_from_switch_with_same_value_with_match(
// CHECK-SAME: %[[FLAG:[a-zA-Z0-9_]+]]
// CHECK-SAME: %[[CASE_OPERAND_0:[a-zA-Z0-9_]+]]
// CHECK-SAME: %[[CASE_OPERAND_1:[a-zA-Z0-9_]+]]
func @switch_from_switch_with_same_value_with_match(%flag : i32, %caseOperand0 : f32, %caseOperand1 : f32) {
// add predecessors for all blocks except ^bb3 to avoid other canonicalizations.
"foo.pred"() [^bb1, ^bb2, ^bb4, ^bb5] : () -> ()
^bb1:
// CHECK: switch %[[FLAG]]
switch %flag : i32, [
default: ^bb2,
42: ^bb3
]
^bb2:
"foo.bb2Terminator"() : () -> ()
^bb3:
// prevent this block from being simplified away
"foo.op"() : () -> ()
// CHECK-NOT: switch %[[FLAG]]
// CHECK: br ^[[BB5:[a-zA-Z0-9_]+]](%[[CASE_OPERAND_1]]
switch %flag : i32, [
default: ^bb4(%caseOperand0 : f32),
42: ^bb5(%caseOperand1 : f32)
]
^bb4(%bb4Arg : f32):
"foo.bb4Terminator"(%bb4Arg) : (f32) -> ()
// CHECK: ^[[BB5]]({{.*}}):
// CHECK-NEXT: "foo.bb5Terminator"
^bb5(%bb5Arg : f32):
"foo.bb5Terminator"(%bb5Arg) : (f32) -> ()
}
// CHECK-LABEL: func @switch_from_switch_with_same_value_no_match(
// CHECK-SAME: %[[FLAG:[a-zA-Z0-9_]+]]
// CHECK-SAME: %[[CASE_OPERAND_0:[a-zA-Z0-9_]+]]
// CHECK-SAME: %[[CASE_OPERAND_1:[a-zA-Z0-9_]+]]
// CHECK-SAME: %[[CASE_OPERAND_2:[a-zA-Z0-9_]+]]
func @switch_from_switch_with_same_value_no_match(%flag : i32, %caseOperand0 : f32, %caseOperand1 : f32, %caseOperand2 : f32) {
// add predecessors for all blocks except ^bb3 to avoid other canonicalizations.
"foo.pred"() [^bb1, ^bb2, ^bb4, ^bb5, ^bb6] : () -> ()
^bb1:
// CHECK: switch %[[FLAG]]
switch %flag : i32, [
default: ^bb2,
42: ^bb3
]
^bb2:
"foo.bb2Terminator"() : () -> ()
^bb3:
"foo.op"() : () -> ()
// CHECK-NOT: switch %[[FLAG]]
// CHECK: br ^[[BB4:[a-zA-Z0-9_]+]](%[[CASE_OPERAND_0]]
switch %flag : i32, [
default: ^bb4(%caseOperand0 : f32),
0: ^bb5(%caseOperand1 : f32),
43: ^bb6(%caseOperand2 : f32)
]
// CHECK: ^[[BB4]]({{.*}})
// CHECK-NEXT: "foo.bb4Terminator"
^bb4(%bb4Arg : f32):
"foo.bb4Terminator"(%bb4Arg) : (f32) -> ()
^bb5(%bb5Arg : f32):
"foo.bb5Terminator"(%bb5Arg) : (f32) -> ()
^bb6(%bb6Arg : f32):
"foo.bb6Terminator"(%bb6Arg) : (f32) -> ()
}
// CHECK-LABEL: func @switch_from_switch_default_with_same_value(
// CHECK-SAME: %[[FLAG:[a-zA-Z0-9_]+]]
// CHECK-SAME: %[[CASE_OPERAND_0:[a-zA-Z0-9_]+]]
// CHECK-SAME: %[[CASE_OPERAND_1:[a-zA-Z0-9_]+]]
// CHECK-SAME: %[[CASE_OPERAND_2:[a-zA-Z0-9_]+]]
func @switch_from_switch_default_with_same_value(%flag : i32, %caseOperand0 : f32, %caseOperand1 : f32, %caseOperand2 : f32) {
// add predecessors for all blocks except ^bb3 to avoid other canonicalizations.
"foo.pred"() [^bb1, ^bb2, ^bb4, ^bb5, ^bb6] : () -> ()
^bb1:
// CHECK: switch %[[FLAG]]
switch %flag : i32, [
default: ^bb3,
42: ^bb2
]
^bb2:
"foo.bb2Terminator"() : () -> ()
^bb3:
"foo.op"() : () -> ()
// CHECK: switch %[[FLAG]]
// CHECK-NEXT: default: ^[[BB4:[a-zA-Z0-9_]+]](%[[CASE_OPERAND_0]]
// CHECK-NEXT: 43: ^[[BB6:[a-zA-Z0-9_]+]](%[[CASE_OPERAND_2]]
// CHECK-NOT: 42
switch %flag : i32, [
default: ^bb4(%caseOperand0 : f32),
42: ^bb5(%caseOperand1 : f32),
43: ^bb6(%caseOperand2 : f32)
]
// CHECK: ^[[BB4]]({{.*}}):
// CHECK-NEXT: "foo.bb4Terminator"
^bb4(%bb4Arg : f32):
"foo.bb4Terminator"(%bb4Arg) : (f32) -> ()
^bb5(%bb5Arg : f32):
"foo.bb5Terminator"(%bb5Arg) : (f32) -> ()
// CHECK: ^[[BB6]]({{.*}}):
// CHECK-NEXT: "foo.bb6Terminator"
^bb6(%bb6Arg : f32):
"foo.bb6Terminator"(%bb6Arg) : (f32) -> ()
}
/// Test folding conditional branches that are successors of conditional
/// branches with the same condition.

View File

@ -96,3 +96,35 @@ func @read_global_memref() {
%1 = memref.tensor_load %0 : memref<2xf32>
return
}
// CHECK-LABEL: func @switch(
func @switch(%flag : i32, %caseOperand : i32) {
switch %flag : i32, [
default: ^bb1(%caseOperand : i32),
42: ^bb2(%caseOperand : i32),
43: ^bb3(%caseOperand : i32)
]
^bb1(%bb1arg : i32):
return
^bb2(%bb2arg : i32):
return
^bb3(%bb3arg : i32):
return
}
// CHECK-LABEL: func @switch_i64(
func @switch_i64(%flag : i64, %caseOperand : i32) {
switch %flag : i64, [
default: ^bb1(%caseOperand : i32),
42: ^bb2(%caseOperand : i32),
43: ^bb3(%caseOperand : i32)
]
^bb1(%bb1arg : i32):
return
^bb2(%bb2arg : i32):
return
^bb3(%bb3arg : i32):
return
}

View File

@ -0,0 +1,69 @@
// RUN: mlir-opt -verify-diagnostics -split-input-file %s
func @switch_missing_case_value(%flag : i32, %caseOperand : i32) {
switch %flag : i32, [
default: ^bb1(%caseOperand : i32),
45: ^bb2(%caseOperand : i32),
// expected-error@+1 {{expected integer value}}
: ^bb3(%caseOperand : i32)
]
^bb1(%bb1arg : i32):
return
^bb2(%bb2arg : i32):
return
^bb3(%bb3arg : i32):
return
}
// -----
func @switch_wrong_type_case_value(%flag : i32, %caseOperand : i32) {
switch %flag : i32, [
default: ^bb1(%caseOperand : i32),
// expected-error@+1 {{expected integer value}}
"hello": ^bb2(%caseOperand : i32)
]
^bb1(%bb1arg : i32):
return
^bb2(%bb2arg : i32):
return
^bb3(%bb3arg : i32):
return
}
// -----
func @switch_missing_comma(%flag : i32, %caseOperand : i32) {
switch %flag : i32, [
default: ^bb1(%caseOperand : i32),
45: ^bb2(%caseOperand : i32)
// expected-error@+1 {{expected ']'}}
43: ^bb3(%caseOperand : i32)
]
^bb1(%bb1arg : i32):
return
^bb2(%bb2arg : i32):
return
^bb3(%bb3arg : i32):
return
}
// -----
func @switch_missing_default(%flag : i32, %caseOperand : i32) {
switch %flag : i32, [
// expected-error@+1 {{expected 'default'}}
45: ^bb2(%caseOperand : i32)
43: ^bb3(%caseOperand : i32)
]
^bb1(%bb1arg : i32):
return
^bb2(%bb2arg : i32):
return
^bb3(%bb3arg : i32):
return
}