[mlir] Split out a new ControlFlow dialect from Standard

This dialect is intended to model lower level/branch based control-flow constructs. The initial set
of operations are: AssertOp, BranchOp, CondBranchOp, SwitchOp; all split out from the current
standard dialect.

See https://discourse.llvm.org/t/standard-dialect-the-final-chapter/6061

Differential Revision: https://reviews.llvm.org/D118966
This commit is contained in:
River Riddle 2022-02-03 20:59:43 -08:00
parent edca177cbe
commit ace01605e0
239 changed files with 3027 additions and 2585 deletions

View File

@ -27,8 +27,8 @@ namespace fir::support {
#define FLANG_NONCODEGEN_DIALECT_LIST \
mlir::AffineDialect, FIROpsDialect, mlir::acc::OpenACCDialect, \
mlir::omp::OpenMPDialect, mlir::scf::SCFDialect, \
mlir::arith::ArithmeticDialect, mlir::StandardOpsDialect, \
mlir::vector::VectorDialect
mlir::arith::ArithmeticDialect, mlir::cf::ControlFlowDialect, \
mlir::StandardOpsDialect, mlir::vector::VectorDialect
// The definitive list of dialects used by flang.
#define FLANG_DIALECT_LIST \

View File

@ -9,7 +9,7 @@
/// This file defines some shared command-line options that can be used when
/// debugging the test tools. This file must be included into the tool.
#include "mlir/Conversion/SCFToStandard/SCFToStandard.h"
#include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/Passes.h"
@ -139,7 +139,7 @@ inline void createDefaultFIROptimizerPassPipeline(mlir::PassManager &pm) {
// convert control flow to CFG form
fir::addCfgConversionPass(pm);
pm.addPass(mlir::createLowerToCFGPass());
pm.addPass(mlir::createConvertSCFToCFPass());
pm.addPass(mlir::createCanonicalizerPass(config));
}

View File

@ -32,7 +32,7 @@ add_flang_library(FortranLower
FortranSemantics
MLIRAffineToStandard
MLIRLLVMIR
MLIRSCFToStandard
MLIRSCFToControlFlow
MLIRStandard
LINK_COMPONENTS

View File

@ -18,6 +18,7 @@
#include "flang/Optimizer/Dialect/FIROps.h"
#include "flang/Optimizer/Support/TypeCode.h"
#include "mlir/Conversion/ArithmeticToLLVM/ArithmeticToLLVM.h"
#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
#include "mlir/IR/BuiltinTypes.h"
@ -3293,6 +3294,8 @@ public:
mlir::populateStdToLLVMConversionPatterns(typeConverter, pattern);
mlir::arith::populateArithmeticToLLVMConversionPatterns(typeConverter,
pattern);
mlir::cf::populateControlFlowToLLVMConversionPatterns(typeConverter,
pattern);
mlir::ConversionTarget target{*context};
target.addLegalDialect<mlir::LLVM::LLVMDialect>();

View File

@ -13,6 +13,7 @@
#include "flang/Optimizer/Dialect/FIRDialect.h"
#include "flang/Optimizer/Support/FIRContext.h"
#include "flang/Optimizer/Transforms/Passes.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/Support/Debug.h"
@ -332,9 +333,9 @@ ArrayCopyAnalysis::arrayAccesses(ArrayLoadOp load) {
<< "add modify {" << *owner << "} to array value set\n");
accesses.push_back(owner);
appendToQueue(update.getResult(1));
} else if (auto br = mlir::dyn_cast<mlir::BranchOp>(owner)) {
} else if (auto br = mlir::dyn_cast<mlir::cf::BranchOp>(owner)) {
branchOp(br.getDest(), br.getDestOperands());
} else if (auto br = mlir::dyn_cast<mlir::CondBranchOp>(owner)) {
} else if (auto br = mlir::dyn_cast<mlir::cf::CondBranchOp>(owner)) {
branchOp(br.getTrueDest(), br.getTrueOperands());
branchOp(br.getFalseDest(), br.getFalseOperands());
} else if (mlir::isa<ArrayMergeStoreOp>(owner)) {
@ -789,9 +790,9 @@ public:
patterns1.insert<ArrayUpdateConversion>(context, analysis, useMap);
patterns1.insert<ArrayModifyConversion>(context, analysis, useMap);
mlir::ConversionTarget target(*context);
target.addLegalDialect<FIROpsDialect, mlir::scf::SCFDialect,
mlir::arith::ArithmeticDialect,
mlir::StandardOpsDialect>();
target.addLegalDialect<
FIROpsDialect, mlir::scf::SCFDialect, mlir::arith::ArithmeticDialect,
mlir::cf::ControlFlowDialect, mlir::StandardOpsDialect>();
target.addIllegalOp<ArrayFetchOp, ArrayUpdateOp, ArrayModifyOp>();
// Rewrite the array fetch and array update ops.
if (mlir::failed(

View File

@ -11,6 +11,7 @@
#include "flang/Optimizer/Dialect/FIROps.h"
#include "flang/Optimizer/Transforms/Passes.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
@ -84,7 +85,7 @@ public:
loopOperands.append(operands.begin(), operands.end());
loopOperands.push_back(iters);
rewriter.create<mlir::BranchOp>(loc, conditionalBlock, loopOperands);
rewriter.create<mlir::cf::BranchOp>(loc, conditionalBlock, loopOperands);
// Last loop block
auto *terminator = lastBlock->getTerminator();
@ -105,7 +106,7 @@ public:
: terminator->operand_begin();
loopCarried.append(begin, terminator->operand_end());
loopCarried.push_back(itersMinusOne);
rewriter.create<mlir::BranchOp>(loc, conditionalBlock, loopCarried);
rewriter.create<mlir::cf::BranchOp>(loc, conditionalBlock, loopCarried);
rewriter.eraseOp(terminator);
// Conditional block
@ -114,9 +115,9 @@ public:
auto comparison = rewriter.create<mlir::arith::CmpIOp>(
loc, arith::CmpIPredicate::sgt, itersLeft, zero);
rewriter.create<mlir::CondBranchOp>(loc, comparison, firstBlock,
llvm::ArrayRef<mlir::Value>(), endBlock,
llvm::ArrayRef<mlir::Value>());
rewriter.create<mlir::cf::CondBranchOp>(
loc, comparison, firstBlock, llvm::ArrayRef<mlir::Value>(), endBlock,
llvm::ArrayRef<mlir::Value>());
// The result of the loop operation is the values of the condition block
// arguments except the induction variable on the last iteration.
@ -155,7 +156,7 @@ public:
} else {
continueBlock =
rewriter.createBlock(remainingOpsBlock, ifOp.getResultTypes());
rewriter.create<mlir::BranchOp>(loc, remainingOpsBlock);
rewriter.create<mlir::cf::BranchOp>(loc, remainingOpsBlock);
}
// Move blocks from the "then" region to the region containing 'fir.if',
@ -165,7 +166,8 @@ public:
auto *ifOpTerminator = ifOpRegion.back().getTerminator();
auto ifOpTerminatorOperands = ifOpTerminator->getOperands();
rewriter.setInsertionPointToEnd(&ifOpRegion.back());
rewriter.create<mlir::BranchOp>(loc, continueBlock, ifOpTerminatorOperands);
rewriter.create<mlir::cf::BranchOp>(loc, continueBlock,
ifOpTerminatorOperands);
rewriter.eraseOp(ifOpTerminator);
rewriter.inlineRegionBefore(ifOpRegion, continueBlock);
@ -179,14 +181,14 @@ public:
auto *otherwiseTerm = otherwiseRegion.back().getTerminator();
auto otherwiseTermOperands = otherwiseTerm->getOperands();
rewriter.setInsertionPointToEnd(&otherwiseRegion.back());
rewriter.create<mlir::BranchOp>(loc, continueBlock,
otherwiseTermOperands);
rewriter.create<mlir::cf::BranchOp>(loc, continueBlock,
otherwiseTermOperands);
rewriter.eraseOp(otherwiseTerm);
rewriter.inlineRegionBefore(otherwiseRegion, continueBlock);
}
rewriter.setInsertionPointToEnd(condBlock);
rewriter.create<mlir::CondBranchOp>(
rewriter.create<mlir::cf::CondBranchOp>(
loc, ifOp.condition(), ifOpBlock, llvm::ArrayRef<mlir::Value>(),
otherwiseBlock, llvm::ArrayRef<mlir::Value>());
rewriter.replaceOp(ifOp, continueBlock->getArguments());
@ -241,7 +243,7 @@ public:
auto begin = whileOp.finalValue() ? std::next(terminator->operand_begin())
: terminator->operand_begin();
loopCarried.append(begin, terminator->operand_end());
rewriter.create<mlir::BranchOp>(loc, conditionBlock, loopCarried);
rewriter.create<mlir::cf::BranchOp>(loc, conditionBlock, loopCarried);
rewriter.eraseOp(terminator);
// Compute loop bounds before branching to the condition.
@ -256,7 +258,7 @@ public:
destOperands.push_back(lowerBound);
auto iterOperands = whileOp.getIterOperands();
destOperands.append(iterOperands.begin(), iterOperands.end());
rewriter.create<mlir::BranchOp>(loc, conditionBlock, destOperands);
rewriter.create<mlir::cf::BranchOp>(loc, conditionBlock, destOperands);
// With the body block done, we can fill in the condition block.
rewriter.setInsertionPointToEnd(conditionBlock);
@ -278,9 +280,9 @@ public:
// Remember to AND in the early-exit bool.
auto comparison =
rewriter.create<mlir::arith::AndIOp>(loc, iterateVar, cmp2);
rewriter.create<mlir::CondBranchOp>(loc, comparison, firstBodyBlock,
llvm::ArrayRef<mlir::Value>(), endBlock,
llvm::ArrayRef<mlir::Value>());
rewriter.create<mlir::cf::CondBranchOp>(
loc, comparison, firstBodyBlock, llvm::ArrayRef<mlir::Value>(),
endBlock, llvm::ArrayRef<mlir::Value>());
// The result of the loop operation is the values of the condition block
// arguments except the induction variable on the last iteration.
auto args = whileOp.finalValue()
@ -300,8 +302,8 @@ public:
patterns.insert<CfgLoopConv, CfgIfConv, CfgIterWhileConv>(
context, forceLoopToExecuteOnce);
mlir::ConversionTarget target(*context);
target.addLegalDialect<mlir::AffineDialect, FIROpsDialect,
mlir::StandardOpsDialect>();
target.addLegalDialect<mlir::AffineDialect, mlir::cf::ControlFlowDialect,
FIROpsDialect, mlir::StandardOpsDialect>();
// apply the patterns
target.addIllegalOp<ResultOp, DoLoopOp, IfOp, IterWhileOp>();

View File

@ -10,10 +10,10 @@ func @select_case_charachter(%arg0: !fir.char<2, 10>, %arg1: !fir.char<2, 10>, %
unit, ^bb3]
^bb1:
%c1_i32 = arith.constant 1 : i32
br ^bb3
cf.br ^bb3
^bb2:
%c2_i32 = arith.constant 2 : i32
br ^bb3
cf.br ^bb3
^bb3:
return
}

View File

@ -1175,23 +1175,23 @@ func @select_case_integer(%arg0: !fir.ref<i32>) -> i32 {
^bb1: // pred: ^bb0
%c1_i32_0 = arith.constant 1 : i32
fir.store %c1_i32_0 to %arg0 : !fir.ref<i32>
br ^bb6
cf.br ^bb6
^bb2: // pred: ^bb0
%c2_i32_1 = arith.constant 2 : i32
fir.store %c2_i32_1 to %arg0 : !fir.ref<i32>
br ^bb6
cf.br ^bb6
^bb3: // pred: ^bb0
%c0_i32 = arith.constant 0 : i32
fir.store %c0_i32 to %arg0 : !fir.ref<i32>
br ^bb6
cf.br ^bb6
^bb4: // pred: ^bb0
%c4_i32_2 = arith.constant 4 : i32
fir.store %c4_i32_2 to %arg0 : !fir.ref<i32>
br ^bb6
cf.br ^bb6
^bb5: // 3 preds: ^bb0, ^bb0, ^bb0
%c7_i32_3 = arith.constant 7 : i32
fir.store %c7_i32_3 to %arg0 : !fir.ref<i32>
br ^bb6
cf.br ^bb6
^bb6: // 5 preds: ^bb1, ^bb2, ^bb3, ^bb4, ^bb5
%3 = fir.load %arg0 : !fir.ref<i32>
return %3 : i32
@ -1275,10 +1275,10 @@ func @select_case_logical(%arg0: !fir.ref<!fir.logical<4>>) {
unit, ^bb3]
^bb1:
%c1_i32 = arith.constant 1 : i32
br ^bb3
cf.br ^bb3
^bb2:
%c2_i32 = arith.constant 2 : i32
br ^bb3
cf.br ^bb3
^bb3:
return
}

View File

@ -9,10 +9,10 @@ func @load_store_chain_removal(%arg0: !fir.ref<!fir.array<60xi32>>, %arg1: !fir.
%c1 = arith.constant 1 : index
%0 = fir.alloca i32 {bindc_name = "i", uniq_name = "_QFf1dcEi"}
%1 = fir.alloca !fir.array<60xi32> {bindc_name = "t1", uniq_name = "_QFf1dcEt1"}
br ^bb1(%c1, %c60 : index, index)
cf.br ^bb1(%c1, %c60 : index, index)
^bb1(%2: index, %3: index): // 2 preds: ^bb0, ^bb2
%4 = arith.cmpi sgt, %3, %c0 : index
cond_br %4, ^bb2, ^bb3
cf.cond_br %4, ^bb2, ^bb3
^bb2: // pred: ^bb1
%5 = fir.convert %2 : (index) -> i32
fir.store %5 to %0 : !fir.ref<i32>
@ -26,14 +26,14 @@ func @load_store_chain_removal(%arg0: !fir.ref<!fir.array<60xi32>>, %arg1: !fir.
fir.store %11 to %12 : !fir.ref<i32>
%13 = arith.addi %2, %c1 : index
%14 = arith.subi %3, %c1 : index
br ^bb1(%13, %14 : index, index)
cf.br ^bb1(%13, %14 : index, index)
^bb3: // pred: ^bb1
%15 = fir.convert %2 : (index) -> i32
fir.store %15 to %0 : !fir.ref<i32>
br ^bb4(%c1, %c60 : index, index)
cf.br ^bb4(%c1, %c60 : index, index)
^bb4(%16: index, %17: index): // 2 preds: ^bb3, ^bb5
%18 = arith.cmpi sgt, %17, %c0 : index
cond_br %18, ^bb5, ^bb6
cf.cond_br %18, ^bb5, ^bb6
^bb5: // pred: ^bb4
%19 = fir.convert %16 : (index) -> i32
fir.store %19 to %0 : !fir.ref<i32>
@ -49,7 +49,7 @@ func @load_store_chain_removal(%arg0: !fir.ref<!fir.array<60xi32>>, %arg1: !fir.
fir.store %27 to %28 : !fir.ref<i32>
%29 = arith.addi %16, %c1 : index
%30 = arith.subi %17, %c1 : index
br ^bb4(%29, %30 : index, index)
cf.br ^bb4(%29, %30 : index, index)
^bb6: // pred: ^bb4
%31 = fir.convert %16 : (index) -> i32
fir.store %31 to %0 : !fir.ref<i32>

View File

@ -13,7 +13,7 @@ FIRTransforms
FIRBuilder
${dialect_libs}
MLIRAffineToStandard
MLIRSCFToStandard
MLIRSCFToControlFlow
FortranCommon
FortranParser
FortranEvaluate

View File

@ -38,7 +38,6 @@
#include "flang/Semantics/semantics.h"
#include "flang/Semantics/unparse-with-symbols.h"
#include "flang/Version.inc"
#include "mlir/Conversion/SCFToStandard/SCFToStandard.h"
#include "mlir/IR/AsmState.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/MLIRContext.h"

View File

@ -18,7 +18,7 @@ target_link_libraries(fir-opt PRIVATE
MLIRTransforms
MLIRAffineToStandard
MLIRAnalysis
MLIRSCFToStandard
MLIRSCFToControlFlow
MLIRParser
MLIRStandardToLLVM
MLIRSupport

View File

@ -17,7 +17,7 @@ target_link_libraries(tco PRIVATE
MLIRTransforms
MLIRAffineToStandard
MLIRAnalysis
MLIRSCFToStandard
MLIRSCFToControlFlow
MLIRParser
MLIRStandardToLLVM
MLIRSupport

View File

@ -17,7 +17,6 @@
#include "flang/Optimizer/Support/InternalNames.h"
#include "flang/Optimizer/Support/KindMapping.h"
#include "flang/Optimizer/Transforms/Passes.h"
#include "mlir/Conversion/SCFToStandard/SCFToStandard.h"
#include "mlir/IR/AsmState.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/MLIRContext.h"

View File

@ -26,7 +26,7 @@ def setup_passes(mlir_module):
f"sparse-tensor-conversion,"
f"builtin.func"
f"(linalg-bufferize,convert-linalg-to-loops,convert-vector-to-scf),"
f"convert-scf-to-std,"
f"convert-scf-to-cf,"
f"func-bufferize,"
f"arith-bufferize,"
f"builtin.func(tensor-bufferize,finalizing-bufferize),"

View File

@ -41,12 +41,12 @@ Example for breaking the invariant:
```mlir
func @condBranch(%arg0: i1, %arg1: memref<2xf32>) {
%0 = memref.alloc() : memref<2xf32>
cond_br %arg0, ^bb1, ^bb2
cf.cond_br %arg0, ^bb1, ^bb2
^bb1:
br ^bb3()
cf.br ^bb3()
^bb2:
partial_write(%0, %0)
br ^bb3()
cf.br ^bb3()
^bb3():
test.copy(%0, %arg1) : (memref<2xf32>, memref<2xf32>) -> ()
return
@ -74,13 +74,13 @@ untracked allocations are mixed:
func @mixedAllocation(%arg0: i1) {
%0 = memref.alloca() : memref<2xf32> // aliases: %2
%1 = memref.alloc() : memref<2xf32> // aliases: %2
cond_br %arg0, ^bb1, ^bb2
cf.cond_br %arg0, ^bb1, ^bb2
^bb1:
use(%0)
br ^bb3(%0 : memref<2xf32>)
cf.br ^bb3(%0 : memref<2xf32>)
^bb2:
use(%1)
br ^bb3(%1 : memref<2xf32>)
cf.br ^bb3(%1 : memref<2xf32>)
^bb3(%2: memref<2xf32>):
...
}
@ -129,13 +129,13 @@ BufferHoisting pass:
```mlir
func @condBranch(%arg0: i1, %arg1: memref<2xf32>, %arg2: memref<2xf32>) {
cond_br %arg0, ^bb1, ^bb2
cf.cond_br %arg0, ^bb1, ^bb2
^bb1:
br ^bb3(%arg1 : memref<2xf32>)
cf.br ^bb3(%arg1 : memref<2xf32>)
^bb2:
%0 = memref.alloc() : memref<2xf32> // aliases: %1
use(%0)
br ^bb3(%0 : memref<2xf32>)
cf.br ^bb3(%0 : memref<2xf32>)
^bb3(%1: memref<2xf32>): // %1 could be %0 or %arg1
test.copy(%1, %arg2) : (memref<2xf32>, memref<2xf32>) -> ()
return
@ -150,12 +150,12 @@ of code:
```mlir
func @condBranch(%arg0: i1, %arg1: memref<2xf32>, %arg2: memref<2xf32>) {
%0 = memref.alloc() : memref<2xf32> // moved to bb0
cond_br %arg0, ^bb1, ^bb2
cf.cond_br %arg0, ^bb1, ^bb2
^bb1:
br ^bb3(%arg1 : memref<2xf32>)
cf.br ^bb3(%arg1 : memref<2xf32>)
^bb2:
use(%0)
br ^bb3(%0 : memref<2xf32>)
cf.br ^bb3(%0 : memref<2xf32>)
^bb3(%1: memref<2xf32>):
test.copy(%1, %arg2) : (memref<2xf32>, memref<2xf32>) -> ()
return
@ -175,14 +175,14 @@ func @condBranchDynamicType(
%arg1: memref<?xf32>,
%arg2: memref<?xf32>,
%arg3: index) {
cond_br %arg0, ^bb1, ^bb2(%arg3: index)
cf.cond_br %arg0, ^bb1, ^bb2(%arg3: index)
^bb1:
br ^bb3(%arg1 : memref<?xf32>)
cf.br ^bb3(%arg1 : memref<?xf32>)
^bb2(%0: index):
%1 = memref.alloc(%0) : memref<?xf32> // cannot be moved upwards to the data
// dependency to %0
use(%1)
br ^bb3(%1 : memref<?xf32>)
cf.br ^bb3(%1 : memref<?xf32>)
^bb3(%2: memref<?xf32>):
test.copy(%2, %arg2) : (memref<?xf32>, memref<?xf32>) -> ()
return
@ -201,14 +201,14 @@ allocations have already been placed:
```mlir
func @branch(%arg0: i1) {
%0 = memref.alloc() : memref<2xf32> // aliases: %2
cond_br %arg0, ^bb1, ^bb2
cf.cond_br %arg0, ^bb1, ^bb2
^bb1:
%1 = memref.alloc() : memref<2xf32> // resides here for demonstration purposes
// aliases: %2
br ^bb3(%1 : memref<2xf32>)
cf.br ^bb3(%1 : memref<2xf32>)
^bb2:
use(%0)
br ^bb3(%0 : memref<2xf32>)
cf.br ^bb3(%0 : memref<2xf32>)
^bb3(%2: memref<2xf32>):
return
@ -233,16 +233,16 @@ result:
```mlir
func @branch(%arg0: i1) {
%0 = memref.alloc() : memref<2xf32>
cond_br %arg0, ^bb1, ^bb2
cf.cond_br %arg0, ^bb1, ^bb2
^bb1:
%1 = memref.alloc() : memref<2xf32>
%3 = bufferization.clone %1 : (memref<2xf32>) -> (memref<2xf32>)
memref.dealloc %1 : memref<2xf32> // %1 can be safely freed here
br ^bb3(%3 : memref<2xf32>)
cf.br ^bb3(%3 : memref<2xf32>)
^bb2:
use(%0)
%4 = bufferization.clone %0 : (memref<2xf32>) -> (memref<2xf32>)
br ^bb3(%4 : memref<2xf32>)
cf.br ^bb3(%4 : memref<2xf32>)
^bb3(%2: memref<2xf32>):
memref.dealloc %2 : memref<2xf32> // free temp buffer %2
@ -273,23 +273,23 @@ func @condBranchDynamicTypeNested(
%arg1: memref<?xf32>, // aliases: %3, %4
%arg2: memref<?xf32>,
%arg3: index) {
cond_br %arg0, ^bb1, ^bb2(%arg3: index)
cf.cond_br %arg0, ^bb1, ^bb2(%arg3: index)
^bb1:
br ^bb6(%arg1 : memref<?xf32>)
cf.br ^bb6(%arg1 : memref<?xf32>)
^bb2(%0: index):
%1 = memref.alloc(%0) : memref<?xf32> // cannot be moved upwards due to the data
// dependency to %0
// aliases: %2, %3, %4
use(%1)
cond_br %arg0, ^bb3, ^bb4
cf.cond_br %arg0, ^bb3, ^bb4
^bb3:
br ^bb5(%1 : memref<?xf32>)
cf.br ^bb5(%1 : memref<?xf32>)
^bb4:
br ^bb5(%1 : memref<?xf32>)
cf.br ^bb5(%1 : memref<?xf32>)
^bb5(%2: memref<?xf32>): // non-crit. alias of %1, since %1 dominates %2
br ^bb6(%2 : memref<?xf32>)
cf.br ^bb6(%2 : memref<?xf32>)
^bb6(%3: memref<?xf32>): // crit. alias of %arg1 and %2 (in other words %1)
br ^bb7(%3 : memref<?xf32>)
cf.br ^bb7(%3 : memref<?xf32>)
^bb7(%4: memref<?xf32>): // non-crit. alias of %3, since %3 dominates %4
test.copy(%4, %arg2) : (memref<?xf32>, memref<?xf32>) -> ()
return
@ -306,25 +306,25 @@ func @condBranchDynamicTypeNested(
%arg1: memref<?xf32>,
%arg2: memref<?xf32>,
%arg3: index) {
cond_br %arg0, ^bb1, ^bb2(%arg3 : index)
cf.cond_br %arg0, ^bb1, ^bb2(%arg3 : index)
^bb1:
// temp buffer required due to alias %3
%5 = bufferization.clone %arg1 : (memref<?xf32>) -> (memref<?xf32>)
br ^bb6(%5 : memref<?xf32>)
cf.br ^bb6(%5 : memref<?xf32>)
^bb2(%0: index):
%1 = memref.alloc(%0) : memref<?xf32>
use(%1)
cond_br %arg0, ^bb3, ^bb4
cf.cond_br %arg0, ^bb3, ^bb4
^bb3:
br ^bb5(%1 : memref<?xf32>)
cf.br ^bb5(%1 : memref<?xf32>)
^bb4:
br ^bb5(%1 : memref<?xf32>)
cf.br ^bb5(%1 : memref<?xf32>)
^bb5(%2: memref<?xf32>):
%6 = bufferization.clone %1 : (memref<?xf32>) -> (memref<?xf32>)
memref.dealloc %1 : memref<?xf32>
br ^bb6(%6 : memref<?xf32>)
cf.br ^bb6(%6 : memref<?xf32>)
^bb6(%3: memref<?xf32>):
br ^bb7(%3 : memref<?xf32>)
cf.br ^bb7(%3 : memref<?xf32>)
^bb7(%4: memref<?xf32>):
test.copy(%4, %arg2) : (memref<?xf32>, memref<?xf32>) -> ()
memref.dealloc %3 : memref<?xf32> // free %3, since %4 is a non-crit. alias of %3

View File

@ -295,7 +295,7 @@ A few examples are shown below:
```mlir
// Expect an error on the same line.
func @bad_branch() {
br ^missing // expected-error {{reference to an undefined block}}
cf.br ^missing // expected-error {{reference to an undefined block}}
}
// Expect an error on an adjacent line.

View File

@ -114,8 +114,8 @@ struct MyTarget : public ConversionTarget {
/// All operations within the GPU dialect are illegal.
addIllegalDialect<GPUDialect>();
/// Mark `std.br` and `std.cond_br` as illegal.
addIllegalOp<BranchOp, CondBranchOp>();
/// Mark `cf.br` and `cf.cond_br` as illegal.
addIllegalOp<cf::BranchOp, cf::CondBranchOp>();
}
/// Implement the default legalization handler to handle operations marked as

View File

@ -23,10 +23,11 @@ argument `-declare-variables-at-top`.
Besides operations part of the EmitC dialect, the Cpp targets supports
translating the following operations:
* 'cf' Dialect
* `cf.br`
* `cf.cond_br`
* 'std' Dialect
* `std.br`
* `std.call`
* `std.cond_br`
* `std.constant`
* `std.return`
* 'scf' Dialect

View File

@ -391,21 +391,21 @@ arguments:
```mlir
func @simple(i64, i1) -> i64 {
^bb0(%a: i64, %cond: i1): // Code dominated by ^bb0 may refer to %a
cond_br %cond, ^bb1, ^bb2
cf.cond_br %cond, ^bb1, ^bb2
^bb1:
br ^bb3(%a: i64) // Branch passes %a as the argument
cf.br ^bb3(%a: i64) // Branch passes %a as the argument
^bb2:
%b = arith.addi %a, %a : i64
br ^bb3(%b: i64) // Branch passes %b as the argument
cf.br ^bb3(%b: i64) // Branch passes %b as the argument
// ^bb3 receives an argument, named %c, from predecessors
// and passes it on to bb4 along with %a. %a is referenced
// directly from its defining operation and is not passed through
// an argument of ^bb3.
^bb3(%c: i64):
br ^bb4(%c, %a : i64, i64)
cf.br ^bb4(%c, %a : i64, i64)
^bb4(%d : i64, %e : i64):
%0 = arith.addi %d, %e : i64
@ -525,12 +525,12 @@ Example:
```mlir
func @accelerator_compute(i64, i1) -> i64 { // An SSACFG region
^bb0(%a: i64, %cond: i1): // Code dominated by ^bb0 may refer to %a
cond_br %cond, ^bb1, ^bb2
cf.cond_br %cond, ^bb1, ^bb2
^bb1:
// This def for %value does not dominate ^bb2
%value = "op.convert"(%a) : (i64) -> i64
br ^bb3(%a: i64) // Branch passes %a as the argument
cf.br ^bb3(%a: i64) // Branch passes %a as the argument
^bb2:
accelerator.launch() { // An SSACFG region

View File

@ -356,24 +356,24 @@ Example output is shown below:
```
//===-------------------------------------------===//
Processing operation : 'std.cond_br'(0x60f000001120) {
"std.cond_br"(%arg0)[^bb2, ^bb2] {operand_segment_sizes = dense<[1, 0, 0]> : vector<3xi32>} : (i1) -> ()
Processing operation : 'cf.cond_br'(0x60f000001120) {
"cf.cond_br"(%arg0)[^bb2, ^bb2] {operand_segment_sizes = dense<[1, 0, 0]> : vector<3xi32>} : (i1) -> ()
* Pattern SimplifyConstCondBranchPred : 'std.cond_br -> ()' {
* Pattern SimplifyConstCondBranchPred : 'cf.cond_br -> ()' {
} -> failure : pattern failed to match
* Pattern SimplifyCondBranchIdenticalSuccessors : 'std.cond_br -> ()' {
** Insert : 'std.br'(0x60b000003690)
** Replace : 'std.cond_br'(0x60f000001120)
* Pattern SimplifyCondBranchIdenticalSuccessors : 'cf.cond_br -> ()' {
** Insert : 'cf.br'(0x60b000003690)
** Replace : 'cf.cond_br'(0x60f000001120)
} -> success : pattern applied successfully
} -> success : pattern matched
//===-------------------------------------------===//
```
This output is describing the processing of a `std.cond_br` operation. We first
This output is describing the processing of a `cf.cond_br` operation. We first
try to apply the `SimplifyConstCondBranchPred`, which fails. From there, another
pattern (`SimplifyCondBranchIdenticalSuccessors`) is applied that matches the
`std.cond_br` and replaces it with a `std.br`.
`cf.cond_br` and replaces it with a `cf.br`.
## Debugging

View File

@ -560,24 +560,24 @@ func @search(%A: memref<?x?xi32>, %S: <?xi32>, %key : i32) {
func @search_body(%A: memref<?x?xi32>, %S: memref<?xi32>, %key: i32, %i : i32) {
%nj = memref.dim %A, 1 : memref<?x?xi32>
br ^bb1(0)
cf.br ^bb1(0)
^bb1(%j: i32)
%p1 = arith.cmpi "lt", %j, %nj : i32
cond_br %p1, ^bb2, ^bb5
cf.cond_br %p1, ^bb2, ^bb5
^bb2:
%v = affine.load %A[%i, %j] : memref<?x?xi32>
%p2 = arith.cmpi "eq", %v, %key : i32
cond_br %p2, ^bb3(%j), ^bb4
cf.cond_br %p2, ^bb3(%j), ^bb4
^bb3(%j: i32)
affine.store %j, %S[%i] : memref<?xi32>
br ^bb5
cf.br ^bb5
^bb4:
%jinc = arith.addi %j, 1 : i32
br ^bb1(%jinc)
cf.br ^bb1(%jinc)
^bb5:
return

View File

@ -94,10 +94,11 @@ multiple stages by relying on
```c++
mlir::RewritePatternSet patterns(&getContext());
mlir::populateAffineToStdConversionPatterns(patterns, &getContext());
mlir::populateLoopToStdConversionPatterns(patterns, &getContext());
mlir::populateSCFToControlFlowConversionPatterns(patterns);
mlir::arith::populateArithmeticToLLVMConversionPatterns(typeConverter,
patterns);
mlir::populateStdToLLVMConversionPatterns(typeConverter, patterns);
mlir::cf::populateControlFlowToLLVMConversionPatterns(typeConverter, patterns);
// The only remaining operation, to lower from the `toy` dialect, is the
// PrintOp.
@ -207,7 +208,7 @@ define void @main() {
%109 = memref.load double, double* %108
%110 = call i32 (i8*, ...) @printf(i8* getelementptr inbounds ([4 x i8], [4 x i8]* @frmt_spec, i64 0, i64 0), double %109)
%111 = add i64 %100, 1
br label %99
br label %99
...

View File

@ -361,7 +361,7 @@
</tspan></tspan><tspan
x="73.476562"
y="88.293896"><tspan
style="font-size:5.64444px">br bb3(%0)</tspan></tspan></text>
style="font-size:5.64444px">cf.br bb3(%0)</tspan></tspan></text>
<text
xml:space="preserve"
id="text1894"

Before

Width:  |  Height:  |  Size: 15 KiB

After

Width:  |  Height:  |  Size: 15 KiB

View File

@ -362,7 +362,7 @@
</tspan></tspan><tspan
x="73.476562"
y="88.293896"><tspan
style="font-size:5.64444px">br bb3(%0)</tspan></tspan></text>
style="font-size:5.64444px">cf.br bb3(%0)</tspan></tspan></text>
<text
xml:space="preserve"
id="text1894"

Before

Width:  |  Height:  |  Size: 15 KiB

After

Width:  |  Height:  |  Size: 15 KiB

View File

@ -26,10 +26,11 @@
#include "mlir/Conversion/AffineToStandard/AffineToStandard.h"
#include "mlir/Conversion/ArithmeticToLLVM/ArithmeticToLLVM.h"
#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
#include "mlir/Conversion/SCFToStandard/SCFToStandard.h"
#include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h"
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
@ -200,10 +201,11 @@ void ToyToLLVMLoweringPass::runOnOperation() {
// set of legal ones.
RewritePatternSet patterns(&getContext());
populateAffineToStdConversionPatterns(patterns);
populateLoopToStdConversionPatterns(patterns);
populateSCFToControlFlowConversionPatterns(patterns);
mlir::arith::populateArithmeticToLLVMConversionPatterns(typeConverter,
patterns);
populateMemRefToLLVMConversionPatterns(typeConverter, patterns);
cf::populateControlFlowToLLVMConversionPatterns(typeConverter, patterns);
populateStdToLLVMConversionPatterns(typeConverter, patterns);
// The only remaining operation to lower from the `toy` dialect, is the

View File

@ -26,10 +26,11 @@
#include "mlir/Conversion/AffineToStandard/AffineToStandard.h"
#include "mlir/Conversion/ArithmeticToLLVM/ArithmeticToLLVM.h"
#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
#include "mlir/Conversion/SCFToStandard/SCFToStandard.h"
#include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h"
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
@ -200,10 +201,11 @@ void ToyToLLVMLoweringPass::runOnOperation() {
// set of legal ones.
RewritePatternSet patterns(&getContext());
populateAffineToStdConversionPatterns(patterns);
populateLoopToStdConversionPatterns(patterns);
populateSCFToControlFlowConversionPatterns(patterns);
mlir::arith::populateArithmeticToLLVMConversionPatterns(typeConverter,
patterns);
populateMemRefToLLVMConversionPatterns(typeConverter, patterns);
cf::populateControlFlowToLLVMConversionPatterns(typeConverter, patterns);
populateStdToLLVMConversionPatterns(typeConverter, patterns);
// The only remaining operation to lower from the `toy` dialect, is the

View File

@ -0,0 +1,35 @@
//===- ControlFlowToLLVM.h - ControlFlow to LLVM -----------*- C++ ------*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Define conversions from the ControlFlow dialect to the LLVM IR dialect.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_CONVERSION_CONTROLFLOWTOLLVM_CONTROLFLOWTOLLVM_H
#define MLIR_CONVERSION_CONTROLFLOWTOLLVM_CONTROLFLOWTOLLVM_H
#include <memory>
namespace mlir {
class LLVMTypeConverter;
class RewritePatternSet;
class Pass;
namespace cf {
/// Collect the patterns to convert from the ControlFlow dialect to LLVM. The
/// conversion patterns capture the LLVMTypeConverter by reference meaning the
/// references have to remain alive during the entire pattern lifetime.
void populateControlFlowToLLVMConversionPatterns(LLVMTypeConverter &converter,
RewritePatternSet &patterns);
/// Creates a pass to convert the ControlFlow dialect into the LLVMIR dialect.
std::unique_ptr<Pass> createConvertControlFlowToLLVMPass();
} // namespace cf
} // namespace mlir
#endif // MLIR_CONVERSION_CONTROLFLOWTOLLVM_CONTROLFLOWTOLLVM_H

View File

@ -0,0 +1,28 @@
//===- ControlFlowToSPIRV.h - CF to SPIR-V Patterns --------*- C++ ------*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Provides patterns to convert ControlFlow dialect to SPIR-V dialect.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_CONVERSION_CONTROLFLOWTOSPIRV_CONTROLFLOWTOSPIRV_H
#define MLIR_CONVERSION_CONTROLFLOWTOSPIRV_CONTROLFLOWTOSPIRV_H
namespace mlir {
class RewritePatternSet;
class SPIRVTypeConverter;
namespace cf {
/// Appends to a pattern list additional patterns for translating ControlFlow
/// ops to SPIR-V ops.
void populateControlFlowToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
RewritePatternSet &patterns);
} // namespace cf
} // namespace mlir
#endif // MLIR_CONVERSION_CONTROLFLOWTOSPIRV_CONTROLFLOWTOSPIRV_H

View File

@ -17,6 +17,8 @@
#include "mlir/Conversion/BufferizationToMemRef/BufferizationToMemRef.h"
#include "mlir/Conversion/ComplexToLLVM/ComplexToLLVM.h"
#include "mlir/Conversion/ComplexToStandard/ComplexToStandard.h"
#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
#include "mlir/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRV.h"
#include "mlir/Conversion/GPUCommon/GPUCommonPass.h"
#include "mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h"
#include "mlir/Conversion/GPUToROCDL/GPUToROCDLPass.h"
@ -35,10 +37,10 @@
#include "mlir/Conversion/OpenMPToLLVM/ConvertOpenMPToLLVM.h"
#include "mlir/Conversion/PDLToPDLInterp/PDLToPDLInterp.h"
#include "mlir/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.h"
#include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h"
#include "mlir/Conversion/SCFToGPU/SCFToGPUPass.h"
#include "mlir/Conversion/SCFToOpenMP/SCFToOpenMP.h"
#include "mlir/Conversion/SCFToSPIRV/SCFToSPIRVPass.h"
#include "mlir/Conversion/SCFToStandard/SCFToStandard.h"
#include "mlir/Conversion/SPIRVToLLVM/SPIRVToLLVMPass.h"
#include "mlir/Conversion/ShapeToStandard/ShapeToStandard.h"
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"

View File

@ -181,6 +181,28 @@ def ConvertComplexToStandard : Pass<"convert-complex-to-standard", "FuncOp"> {
let dependentDialects = ["math::MathDialect"];
}
//===----------------------------------------------------------------------===//
// ControlFlowToLLVM
//===----------------------------------------------------------------------===//
def ConvertControlFlowToLLVM : Pass<"convert-cf-to-llvm", "ModuleOp"> {
let summary = "Convert ControlFlow operations to the LLVM dialect";
let description = [{
Convert ControlFlow operations into LLVM IR dialect operations.
If other operations are present and their results are required by the LLVM
IR dialect operations, the pass will fail. Any LLVM IR operations or types
already present in the IR will be kept as is.
}];
let constructor = "mlir::cf::createConvertControlFlowToLLVMPass()";
let dependentDialects = ["LLVM::LLVMDialect"];
let options = [
Option<"indexBitwidth", "index-bitwidth", "unsigned",
/*default=kDeriveIndexBitwidthFromDataLayout*/"0",
"Bitwidth of the index type, 0 to use size of machine word">,
];
}
//===----------------------------------------------------------------------===//
// GPUCommon
//===----------------------------------------------------------------------===//
@ -460,6 +482,17 @@ def ReconcileUnrealizedCasts : Pass<"reconcile-unrealized-casts"> {
let constructor = "mlir::createReconcileUnrealizedCastsPass()";
}
//===----------------------------------------------------------------------===//
// SCFToControlFlow
//===----------------------------------------------------------------------===//
def SCFToControlFlow : Pass<"convert-scf-to-cf"> {
let summary = "Convert SCF dialect to ControlFlow dialect, replacing structured"
" control flow with a CFG";
let constructor = "mlir::createConvertSCFToCFPass()";
let dependentDialects = ["cf::ControlFlowDialect"];
}
//===----------------------------------------------------------------------===//
// SCFToOpenMP
//===----------------------------------------------------------------------===//
@ -488,17 +521,6 @@ def SCFToSPIRV : Pass<"convert-scf-to-spirv", "ModuleOp"> {
let dependentDialects = ["spirv::SPIRVDialect"];
}
//===----------------------------------------------------------------------===//
// SCFToStandard
//===----------------------------------------------------------------------===//
def SCFToStandard : Pass<"convert-scf-to-std"> {
let summary = "Convert SCF dialect to Standard dialect, replacing structured"
" control flow with a CFG";
let constructor = "mlir::createLowerToCFGPass()";
let dependentDialects = ["StandardOpsDialect"];
}
//===----------------------------------------------------------------------===//
// SCFToGPU
//===----------------------------------------------------------------------===//
@ -547,7 +569,7 @@ def ConvertShapeConstraints: Pass<"convert-shape-constraints", "FuncOp"> {
computation lowering.
}];
let constructor = "mlir::createConvertShapeConstraintsPass()";
let dependentDialects = ["StandardOpsDialect", "scf::SCFDialect"];
let dependentDialects = ["cf::ControlFlowDialect", "scf::SCFDialect"];
}
//===----------------------------------------------------------------------===//

View File

@ -0,0 +1,28 @@
//===- SCFToControlFlow.h - Pass entrypoint ---------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_CONVERSION_SCFTOCONTROLFLOW_SCFTOCONTROLFLOW_H_
#define MLIR_CONVERSION_SCFTOCONTROLFLOW_SCFTOCONTROLFLOW_H_
#include <memory>
namespace mlir {
class Pass;
class RewritePatternSet;
/// Collect a set of patterns to convert SCF operations to CFG branch-based
/// operations within the ControlFlow dialect.
void populateSCFToControlFlowConversionPatterns(RewritePatternSet &patterns);
/// Creates a pass to convert SCF operations to CFG branch-based operation in
/// the ControlFlow dialect.
std::unique_ptr<Pass> createConvertSCFToCFPass();
} // namespace mlir
#endif // MLIR_CONVERSION_SCFTOCONTROLFLOW_SCFTOCONTROLFLOW_H_

View File

@ -1,31 +0,0 @@
//===- ConvertSCFToStandard.h - Pass entrypoint -----------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_CONVERSION_SCFTOSTANDARD_SCFTOSTANDARD_H_
#define MLIR_CONVERSION_SCFTOSTANDARD_SCFTOSTANDARD_H_
#include <memory>
#include <vector>
namespace mlir {
struct LogicalResult;
class Pass;
class RewritePatternSet;
/// Collect a set of patterns to lower from scf.for, scf.if, and
/// loop.terminator to CFG operations within the Standard dialect, in particular
/// convert structured control flow into CFG branch-based control flow.
void populateLoopToStdConversionPatterns(RewritePatternSet &patterns);
/// Creates a pass to convert scf.for, scf.if and loop.terminator ops to CFG.
std::unique_ptr<Pass> createLowerToCFGPass();
} // namespace mlir
#endif // MLIR_CONVERSION_SCFTOSTANDARD_SCFTOSTANDARD_H_

View File

@ -26,9 +26,9 @@ def BufferDeallocation : Pass<"buffer-deallocation", "FuncOp"> {
#map0 = affine_map<(d0) -> (d0)>
module {
func @condBranch(%arg0: i1, %arg1: memref<2xf32>, %arg2: memref<2xf32>) {
cond_br %arg0, ^bb1, ^bb2
cf.cond_br %arg0, ^bb1, ^bb2
^bb1:
br ^bb3(%arg1 : memref<2xf32>)
cf.br ^bb3(%arg1 : memref<2xf32>)
^bb2:
%0 = memref.alloc() : memref<2xf32>
linalg.generic {
@ -40,7 +40,7 @@ def BufferDeallocation : Pass<"buffer-deallocation", "FuncOp"> {
%tmp1 = exp %gen1_arg0 : f32
linalg.yield %tmp1 : f32
}: memref<2xf32>, memref<2xf32>
br ^bb3(%0 : memref<2xf32>)
cf.br ^bb3(%0 : memref<2xf32>)
^bb3(%1: memref<2xf32>):
"memref.copy"(%1, %arg2) : (memref<2xf32>, memref<2xf32>) -> ()
return
@ -55,11 +55,11 @@ def BufferDeallocation : Pass<"buffer-deallocation", "FuncOp"> {
#map0 = affine_map<(d0) -> (d0)>
module {
func @condBranch(%arg0: i1, %arg1: memref<2xf32>, %arg2: memref<2xf32>) {
cond_br %arg0, ^bb1, ^bb2
cf.cond_br %arg0, ^bb1, ^bb2
^bb1: // pred: ^bb0
%0 = memref.alloc() : memref<2xf32>
memref.copy(%arg1, %0) : memref<2xf32>, memref<2xf32>
br ^bb3(%0 : memref<2xf32>)
cf.br ^bb3(%0 : memref<2xf32>)
^bb2: // pred: ^bb0
%1 = memref.alloc() : memref<2xf32>
linalg.generic {
@ -74,7 +74,7 @@ def BufferDeallocation : Pass<"buffer-deallocation", "FuncOp"> {
%2 = memref.alloc() : memref<2xf32>
memref.copy(%1, %2) : memref<2xf32>, memref<2xf32>
dealloc %1 : memref<2xf32>
br ^bb3(%2 : memref<2xf32>)
cf.br ^bb3(%2 : memref<2xf32>)
^bb3(%3: memref<2xf32>): // 2 preds: ^bb1, ^bb2
memref.copy(%3, %arg2) : memref<2xf32>, memref<2xf32>
dealloc %3 : memref<2xf32>

View File

@ -6,6 +6,7 @@ add_subdirectory(ArmSVE)
add_subdirectory(AMX)
add_subdirectory(Bufferization)
add_subdirectory(Complex)
add_subdirectory(ControlFlow)
add_subdirectory(DLTI)
add_subdirectory(EmitC)
add_subdirectory(GPU)

View File

@ -0,0 +1 @@
add_subdirectory(IR)

View File

@ -0,0 +1,2 @@
add_mlir_dialect(ControlFlowOps cf ControlFlowOps)
add_mlir_doc(ControlFlowOps ControlFlowDialect Dialects/ -gen-dialect-doc)

View File

@ -0,0 +1,21 @@
//===- ControlFlow.h - ControlFlow Dialect ----------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file defines the ControlFlow dialect.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_DIALECT_CONTROLFLOW_IR_CONTROLFLOW_H
#define MLIR_DIALECT_CONTROLFLOW_IR_CONTROLFLOW_H
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/IR/Dialect.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOpsDialect.h.inc"
#endif // MLIR_DIALECT_CONTROLFLOW_IR_CONTROLFLOW_H

View File

@ -0,0 +1,30 @@
//===- ControlFlowOps.h - ControlFlow Operations ----------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file defines the operations of the ControlFlow dialect.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_DIALECT_CONTROLFLOW_IR_CONTROLFLOWOPS_H
#define MLIR_DIALECT_CONTROLFLOW_IR_CONTROLFLOWOPS_H
#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
namespace mlir {
class PatternRewriter;
} // namespace mlir
#define GET_OP_CLASSES
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h.inc"
#endif // MLIR_DIALECT_CONTROLFLOW_IR_CONTROLFLOWOPS_H

View File

@ -0,0 +1,313 @@
//===- ControlFlowOps.td - ControlFlow operations ----------*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file contains definitions for the operations within the ControlFlow
// dialect.
//
//===----------------------------------------------------------------------===//
// Use a dialect-specific include guard: the previous `STANDARD_OPS` guard was
// copy-pasted from StandardOps/IR/Ops.td (which still defines it), so any
// translation unit including both .td files would silently drop one of them.
#ifndef CONTROLFLOW_OPS
#define CONTROLFLOW_OPS

include "mlir/IR/OpAsmInterface.td"
include "mlir/Interfaces/ControlFlowInterfaces.td"
include "mlir/Interfaces/SideEffectInterfaces.td"

// Definition of the ControlFlow dialect: low-level, non-region based control
// flow constructs that operate directly on the SSA blocks of a CFG.
def ControlFlow_Dialect : Dialect {
  let name = "cf";
  let cppNamespace = "::mlir::cf";
  let dependentDialects = ["arith::ArithmeticDialect"];
  let emitAccessorPrefix = kEmitAccessorPrefix_Prefixed;
  let description = [{
    This dialect contains low-level, i.e. non-region based, control flow
    constructs. These constructs generally represent control flow directly
    on SSA blocks of a control flow graph.
  }];
}

// Base class for all operations in the ControlFlow dialect.
class CF_Op<string mnemonic, list<Trait> traits = []> :
    Op<ControlFlow_Dialect, mnemonic, traits>;
//===----------------------------------------------------------------------===//
// AssertOp
//===----------------------------------------------------------------------===//
// Runtime assertion on an i1 value, with a diagnostic message attribute.
def AssertOp : CF_Op<"assert"> {
  let summary = "Assert operation with message attribute";
  let description = [{
    Assert operation with single boolean operand and an error message attribute.
    If the argument is `true` this operation has no effect. Otherwise, the
    program execution will abort. The provided error message may be used by a
    runtime to propagate the error to the user.

    Example:

    ```mlir
    cf.assert %b, "Expected ... to be true"
    ```
  }];

  let arguments = (ins I1:$arg, StrAttr:$msg);
  let assemblyFormat = "$arg `,` $msg attr-dict";
  let hasCanonicalizeMethod = 1;
}
//===----------------------------------------------------------------------===//
// BranchOp
//===----------------------------------------------------------------------===//
// Unconditional branch terminator; forwards its operands as the block
// arguments of the single successor.
def BranchOp : CF_Op<"br", [
    DeclareOpInterfaceMethods<BranchOpInterface, ["getSuccessorForOperands"]>,
    NoSideEffect, Terminator
  ]> {
  let summary = "branch operation";
  let description = [{
    The `cf.br` operation represents a direct branch operation to a given
    block. The operands of this operation are forwarded to the successor block,
    and the number and type of the operands must match the arguments of the
    target block.

    Example:

    ```mlir
    ^bb2:
      %2 = call @someFn()
      cf.br ^bb3(%2 : tensor<*xf32>)
    ^bb3(%3: tensor<*xf32>):
    ```
  }];

  let arguments = (ins Variadic<AnyType>:$destOperands);
  let successors = (successor AnySuccessor:$dest);

  let builders = [
    OpBuilder<(ins "Block *":$dest,
                   CArg<"ValueRange", "{}">:$destOperands), [{
      $_state.addSuccessors(dest);
      $_state.addOperands(destOperands);
    }]>];

  let extraClassDeclaration = [{
    /// Set the destination (successor) block of this branch.
    void setDest(Block *block);

    /// Erase the operand at 'index' from the operand list.
    void eraseOperand(unsigned index);
  }];

  let hasCanonicalizeMethod = 1;
  let assemblyFormat = [{
    $dest (`(` $destOperands^ `:` type($destOperands) `)`)? attr-dict
  }];
}
//===----------------------------------------------------------------------===//
// CondBranchOp
//===----------------------------------------------------------------------===//
// Conditional branch terminator: jumps to one of two successors based on an
// i1 condition, forwarding per-successor operand lists.
def CondBranchOp : CF_Op<"cond_br",
    [AttrSizedOperandSegments,
     DeclareOpInterfaceMethods<BranchOpInterface, ["getSuccessorForOperands"]>,
     NoSideEffect, Terminator]> {
  let summary = "conditional branch operation";
  let description = [{
    The `cf.cond_br` terminator operation represents a conditional branch on a
    boolean (1-bit integer) value. If the bit is set, then the first destination
    is jumped to; if it is false, the second destination is chosen. The count
    and types of operands must align with the arguments in the corresponding
    target blocks.

    The MLIR conditional branch operation is not allowed to target the entry
    block for a region. The two destinations of the conditional branch operation
    are allowed to be the same.

    The following example illustrates a function with a conditional branch
    operation that targets the same block.

    Example:

    ```mlir
    func @select(%a: i32, %b: i32, %flag: i1) -> i32 {
      // Both targets are the same, operands differ
      cf.cond_br %flag, ^bb1(%a : i32), ^bb1(%b : i32)

    ^bb1(%x : i32) :
      return %x : i32
    }
    ```
  }];

  let arguments = (ins I1:$condition,
                       Variadic<AnyType>:$trueDestOperands,
                       Variadic<AnyType>:$falseDestOperands);
  let successors = (successor AnySuccessor:$trueDest, AnySuccessor:$falseDest);

  let builders = [
    OpBuilder<(ins "Value":$condition, "Block *":$trueDest,
      "ValueRange":$trueOperands, "Block *":$falseDest,
      "ValueRange":$falseOperands), [{
      build($_builder, $_state, condition, trueOperands, falseOperands, trueDest,
            falseDest);
    }]>,
    OpBuilder<(ins "Value":$condition, "Block *":$trueDest,
      "Block *":$falseDest, CArg<"ValueRange", "{}">:$falseOperands), [{
      build($_builder, $_state, condition, trueDest, ValueRange(), falseDest,
            falseOperands);
    }]>];

  let extraClassDeclaration = [{
    // These are the indices into the dests list.
    enum { trueIndex = 0, falseIndex = 1 };

    // Accessors for operands to the 'true' destination.
    Value getTrueOperand(unsigned idx) {
      assert(idx < getNumTrueOperands());
      return getOperand(getTrueDestOperandIndex() + idx);
    }
    void setTrueOperand(unsigned idx, Value value) {
      assert(idx < getNumTrueOperands());
      setOperand(getTrueDestOperandIndex() + idx, value);
    }
    unsigned getNumTrueOperands() { return getTrueOperands().size(); }

    /// Erase the operand at 'index' from the true operand list.
    void eraseTrueOperand(unsigned index) {
      getTrueDestOperandsMutable().erase(index);
    }

    // Accessors for operands to the 'false' destination.
    Value getFalseOperand(unsigned idx) {
      assert(idx < getNumFalseOperands());
      return getOperand(getFalseDestOperandIndex() + idx);
    }
    void setFalseOperand(unsigned idx, Value value) {
      assert(idx < getNumFalseOperands());
      setOperand(getFalseDestOperandIndex() + idx, value);
    }

    operand_range getTrueOperands() { return getTrueDestOperands(); }
    operand_range getFalseOperands() { return getFalseDestOperands(); }

    unsigned getNumFalseOperands() { return getFalseOperands().size(); }

    /// Erase the operand at 'index' from the false operand list.
    void eraseFalseOperand(unsigned index) {
      getFalseDestOperandsMutable().erase(index);
    }

  private:
    /// Get the index of the first true destination operand.
    unsigned getTrueDestOperandIndex() { return 1; }

    /// Get the index of the first false destination operand.
    unsigned getFalseDestOperandIndex() {
      return getTrueDestOperandIndex() + getNumTrueOperands();
    }
  }];

  let hasCanonicalizer = 1;
  let assemblyFormat = [{
    $condition `,`
    $trueDest (`(` $trueDestOperands^ `:` type($trueDestOperands) `)`)? `,`
    $falseDest (`(` $falseDestOperands^ `:` type($falseDestOperands) `)`)?
    attr-dict
  }];
}
//===----------------------------------------------------------------------===//
// SwitchOp
//===----------------------------------------------------------------------===//
// Multi-way branch terminator on a signless integer flag, with a default
// destination and an arbitrary number of cases.
def SwitchOp : CF_Op<"switch",
    [AttrSizedOperandSegments,
     DeclareOpInterfaceMethods<BranchOpInterface, ["getSuccessorForOperands"]>,
     NoSideEffect, Terminator]> {
  let summary = "switch operation";
  let description = [{
    The `cf.switch` terminator operation represents a switch on a signless
    integer value. If the flag matches one of the specified cases, then the
    corresponding destination is jumped to. If the flag does not match any of
    the cases, the default destination is jumped to. The count and types of
    operands must align with the arguments in the corresponding target blocks.

    Example:

    ```mlir
    cf.switch %flag : i32, [
      default: ^bb1(%a : i32),
      42: ^bb1(%b : i32),
      43: ^bb3(%c : i32)
    ]
    ```
  }];

  let arguments = (ins
    AnyInteger:$flag,
    Variadic<AnyType>:$defaultOperands,
    VariadicOfVariadic<AnyType, "case_operand_segments">:$caseOperands,
    OptionalAttr<AnyIntElementsAttr>:$case_values,
    I32ElementsAttr:$case_operand_segments
  );
  let successors = (successor
    AnySuccessor:$defaultDestination,
    VariadicSuccessor<AnySuccessor>:$caseDestinations
  );

  let builders = [
    OpBuilder<(ins "Value":$flag,
      "Block *":$defaultDestination,
      "ValueRange":$defaultOperands,
      CArg<"ArrayRef<APInt>", "{}">:$caseValues,
      CArg<"BlockRange", "{}">:$caseDestinations,
      CArg<"ArrayRef<ValueRange>", "{}">:$caseOperands)>,
    OpBuilder<(ins "Value":$flag,
      "Block *":$defaultDestination,
      "ValueRange":$defaultOperands,
      CArg<"ArrayRef<int32_t>", "{}">:$caseValues,
      CArg<"BlockRange", "{}">:$caseDestinations,
      CArg<"ArrayRef<ValueRange>", "{}">:$caseOperands)>,
    OpBuilder<(ins "Value":$flag,
      "Block *":$defaultDestination,
      "ValueRange":$defaultOperands,
      CArg<"DenseIntElementsAttr", "{}">:$caseValues,
      CArg<"BlockRange", "{}">:$caseDestinations,
      CArg<"ArrayRef<ValueRange>", "{}">:$caseOperands)>
  ];

  let assemblyFormat = [{
    $flag `:` type($flag) `,` `[` `\n`
      custom<SwitchOpCases>(ref(type($flag)),$defaultDestination,
                            $defaultOperands,
                            type($defaultOperands),
                            $case_values,
                            $caseDestinations,
                            $caseOperands,
                            type($caseOperands))
   `]`
    attr-dict
  }];

  let extraClassDeclaration = [{
    /// Return the operands for the case destination block at the given index.
    OperandRange getCaseOperands(unsigned index) {
      return getCaseOperands()[index];
    }

    /// Return a mutable range of operands for the case destination block at the
    /// given index.
    MutableOperandRange getCaseOperandsMutable(unsigned index) {
      return getCaseOperandsMutable()[index];
    }
  }];

  let hasCanonicalizer = 1;
  let hasVerifier = 1;
}
#endif // CONTROLFLOW_OPS

View File

@ -84,15 +84,15 @@ def ExecuteRegionOp : SCF_Op<"execute_region"> {
affine.for %i = 0 to 100 {
"foo"() : () -> ()
%v = scf.execute_region -> i64 {
cond_br %cond, ^bb1, ^bb2
cf.cond_br %cond, ^bb1, ^bb2
^bb1:
%c1 = arith.constant 1 : i64
br ^bb3(%c1 : i64)
cf.br ^bb3(%c1 : i64)
^bb2:
%c2 = arith.constant 2 : i64
br ^bb3(%c2 : i64)
cf.br ^bb3(%c2 : i64)
^bb3(%x : i64):
scf.yield %x : i64

View File

@ -14,7 +14,7 @@
#ifndef MLIR_DIALECT_STANDARDOPS_IR_OPS_H
#define MLIR_DIALECT_STANDARDOPS_IR_OPS_H
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Dialect.h"
@ -24,7 +24,6 @@
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Interfaces/VectorInterfaces.h"
// Pull in all enum type definitions and utility function declarations.
#include "mlir/Dialect/StandardOps/IR/OpsEnums.h.inc"

View File

@ -20,12 +20,11 @@ include "mlir/Interfaces/CastInterfaces.td"
include "mlir/Interfaces/ControlFlowInterfaces.td"
include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/Interfaces/VectorInterfaces.td"
def StandardOps_Dialect : Dialect {
let name = "std";
let cppNamespace = "::mlir";
let dependentDialects = ["arith::ArithmeticDialect"];
let dependentDialects = ["cf::ControlFlowDialect"];
let hasConstantMaterializer = 1;
let emitAccessorPrefix = kEmitAccessorPrefix_Prefixed;
}
@ -42,78 +41,6 @@ class Std_Op<string mnemonic, list<Trait> traits = []> :
let parser = [{ return ::parse$cppClass(parser, result); }];
}
//===----------------------------------------------------------------------===//
// AssertOp
//===----------------------------------------------------------------------===//
def AssertOp : Std_Op<"assert"> {
let summary = "Assert operation with message attribute";
let description = [{
Assert operation with single boolean operand and an error message attribute.
If the argument is `true` this operation has no effect. Otherwise, the
program execution will abort. The provided error message may be used by a
runtime to propagate the error to the user.
Example:
```mlir
assert %b, "Expected ... to be true"
```
}];
let arguments = (ins I1:$arg, StrAttr:$msg);
let assemblyFormat = "$arg `,` $msg attr-dict";
let hasCanonicalizeMethod = 1;
}
//===----------------------------------------------------------------------===//
// BranchOp
//===----------------------------------------------------------------------===//
def BranchOp : Std_Op<"br",
[DeclareOpInterfaceMethods<BranchOpInterface, ["getSuccessorForOperands"]>,
NoSideEffect, Terminator]> {
let summary = "branch operation";
let description = [{
The `br` operation represents a branch operation in a function.
The operation takes variable number of operands and produces no results.
The operand number and types for each successor must match the arguments of
the block successor.
Example:
```mlir
^bb2:
%2 = call @someFn()
br ^bb3(%2 : tensor<*xf32>)
^bb3(%3: tensor<*xf32>):
```
}];
let arguments = (ins Variadic<AnyType>:$destOperands);
let successors = (successor AnySuccessor:$dest);
let builders = [
OpBuilder<(ins "Block *":$dest,
CArg<"ValueRange", "{}">:$destOperands), [{
$_state.addSuccessors(dest);
$_state.addOperands(destOperands);
}]>];
let extraClassDeclaration = [{
void setDest(Block *block);
/// Erase the operand at 'index' from the operand list.
void eraseOperand(unsigned index);
}];
let hasCanonicalizeMethod = 1;
let assemblyFormat = [{
$dest (`(` $destOperands^ `:` type($destOperands) `)`)? attr-dict
}];
}
//===----------------------------------------------------------------------===//
// CallOp
//===----------------------------------------------------------------------===//
@ -246,121 +173,6 @@ def CallIndirectOp : Std_Op<"call_indirect", [
"$callee `(` $callee_operands `)` attr-dict `:` type($callee)";
}
//===----------------------------------------------------------------------===//
// CondBranchOp
//===----------------------------------------------------------------------===//
def CondBranchOp : Std_Op<"cond_br",
[AttrSizedOperandSegments,
DeclareOpInterfaceMethods<BranchOpInterface, ["getSuccessorForOperands"]>,
NoSideEffect, Terminator]> {
let summary = "conditional branch operation";
let description = [{
The `cond_br` terminator operation represents a conditional branch on a
boolean (1-bit integer) value. If the bit is set, then the first destination
is jumped to; if it is false, the second destination is chosen. The count
and types of operands must align with the arguments in the corresponding
target blocks.
The MLIR conditional branch operation is not allowed to target the entry
block for a region. The two destinations of the conditional branch operation
are allowed to be the same.
The following example illustrates a function with a conditional branch
operation that targets the same block.
Example:
```mlir
func @select(%a: i32, %b: i32, %flag: i1) -> i32 {
// Both targets are the same, operands differ
cond_br %flag, ^bb1(%a : i32), ^bb1(%b : i32)
^bb1(%x : i32) :
return %x : i32
}
```
}];
let arguments = (ins I1:$condition,
Variadic<AnyType>:$trueDestOperands,
Variadic<AnyType>:$falseDestOperands);
let successors = (successor AnySuccessor:$trueDest, AnySuccessor:$falseDest);
let builders = [
OpBuilder<(ins "Value":$condition, "Block *":$trueDest,
"ValueRange":$trueOperands, "Block *":$falseDest,
"ValueRange":$falseOperands), [{
build($_builder, $_state, condition, trueOperands, falseOperands, trueDest,
falseDest);
}]>,
OpBuilder<(ins "Value":$condition, "Block *":$trueDest,
"Block *":$falseDest, CArg<"ValueRange", "{}">:$falseOperands), [{
build($_builder, $_state, condition, trueDest, ValueRange(), falseDest,
falseOperands);
}]>];
let extraClassDeclaration = [{
// These are the indices into the dests list.
enum { trueIndex = 0, falseIndex = 1 };
// Accessors for operands to the 'true' destination.
Value getTrueOperand(unsigned idx) {
assert(idx < getNumTrueOperands());
return getOperand(getTrueDestOperandIndex() + idx);
}
void setTrueOperand(unsigned idx, Value value) {
assert(idx < getNumTrueOperands());
setOperand(getTrueDestOperandIndex() + idx, value);
}
unsigned getNumTrueOperands() { return getTrueOperands().size(); }
/// Erase the operand at 'index' from the true operand list.
void eraseTrueOperand(unsigned index) {
getTrueDestOperandsMutable().erase(index);
}
// Accessors for operands to the 'false' destination.
Value getFalseOperand(unsigned idx) {
assert(idx < getNumFalseOperands());
return getOperand(getFalseDestOperandIndex() + idx);
}
void setFalseOperand(unsigned idx, Value value) {
assert(idx < getNumFalseOperands());
setOperand(getFalseDestOperandIndex() + idx, value);
}
operand_range getTrueOperands() { return getTrueDestOperands(); }
operand_range getFalseOperands() { return getFalseDestOperands(); }
unsigned getNumFalseOperands() { return getFalseOperands().size(); }
/// Erase the operand at 'index' from the false operand list.
void eraseFalseOperand(unsigned index) {
getFalseDestOperandsMutable().erase(index);
}
private:
/// Get the index of the first true destination operand.
unsigned getTrueDestOperandIndex() { return 1; }
/// Get the index of the first false destination operand.
unsigned getFalseDestOperandIndex() {
return getTrueDestOperandIndex() + getNumTrueOperands();
}
}];
let hasCanonicalizer = 1;
let assemblyFormat = [{
$condition `,`
$trueDest (`(` $trueDestOperands^ `:` type($trueDestOperands) `)`)? `,`
$falseDest (`(` $falseDestOperands^ `:` type($falseDestOperands) `)`)?
attr-dict
}];
}
//===----------------------------------------------------------------------===//
// ConstantOp
//===----------------------------------------------------------------------===//
@ -443,93 +255,4 @@ def ReturnOp : Std_Op<"return", [NoSideEffect, HasParent<"FuncOp">,
let hasVerifier = 1;
}
//===----------------------------------------------------------------------===//
// SwitchOp
//===----------------------------------------------------------------------===//
def SwitchOp : Std_Op<"switch",
[AttrSizedOperandSegments,
DeclareOpInterfaceMethods<BranchOpInterface, ["getSuccessorForOperands"]>,
NoSideEffect, Terminator]> {
let summary = "switch operation";
let description = [{
The `switch` terminator operation represents a switch on a signless integer
value. If the flag matches one of the specified cases, then the
corresponding destination is jumped to. If the flag does not match any of
the cases, the default destination is jumped to. The count and types of
operands must align with the arguments in the corresponding target blocks.
Example:
```mlir
switch %flag : i32, [
default: ^bb1(%a : i32),
42: ^bb1(%b : i32),
43: ^bb3(%c : i32)
]
```
}];
let arguments = (ins
AnyInteger:$flag,
Variadic<AnyType>:$defaultOperands,
VariadicOfVariadic<AnyType, "case_operand_segments">:$caseOperands,
OptionalAttr<AnyIntElementsAttr>:$case_values,
I32ElementsAttr:$case_operand_segments
);
let successors = (successor
AnySuccessor:$defaultDestination,
VariadicSuccessor<AnySuccessor>:$caseDestinations
);
let builders = [
OpBuilder<(ins "Value":$flag,
"Block *":$defaultDestination,
"ValueRange":$defaultOperands,
CArg<"ArrayRef<APInt>", "{}">:$caseValues,
CArg<"BlockRange", "{}">:$caseDestinations,
CArg<"ArrayRef<ValueRange>", "{}">:$caseOperands)>,
OpBuilder<(ins "Value":$flag,
"Block *":$defaultDestination,
"ValueRange":$defaultOperands,
CArg<"ArrayRef<int32_t>", "{}">:$caseValues,
CArg<"BlockRange", "{}">:$caseDestinations,
CArg<"ArrayRef<ValueRange>", "{}">:$caseOperands)>,
OpBuilder<(ins "Value":$flag,
"Block *":$defaultDestination,
"ValueRange":$defaultOperands,
CArg<"DenseIntElementsAttr", "{}">:$caseValues,
CArg<"BlockRange", "{}">:$caseDestinations,
CArg<"ArrayRef<ValueRange>", "{}">:$caseOperands)>
];
let assemblyFormat = [{
$flag `:` type($flag) `,` `[` `\n`
custom<SwitchOpCases>(ref(type($flag)),$defaultDestination,
$defaultOperands,
type($defaultOperands),
$case_values,
$caseDestinations,
$caseOperands,
type($caseOperands))
`]`
attr-dict
}];
let extraClassDeclaration = [{
/// Return the operands for the case destination block at the given index.
OperandRange getCaseOperands(unsigned index) {
return getCaseOperands()[index];
}
/// Return a mutable range of operands for the case destination block at the
/// given index.
MutableOperandRange getCaseOperandsMutable(unsigned index) {
return getCaseOperandsMutable()[index];
}
}];
let hasCanonicalizer = 1;
let hasVerifier = 1;
}
#endif // STANDARD_OPS

View File

@ -22,6 +22,7 @@
#include "mlir/Dialect/Async/IR/Async.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
#include "mlir/Dialect/DLTI/DLTI.h"
#include "mlir/Dialect/EmitC/IR/EmitC.h"
#include "mlir/Dialect/GPU/GPUDialect.h"
@ -61,6 +62,7 @@ inline void registerAllDialects(DialectRegistry &registry) {
arm_neon::ArmNeonDialect,
async::AsyncDialect,
bufferization::BufferizationDialect,
cf::ControlFlowDialect,
complex::ComplexDialect,
DLTIDialect,
emitc::EmitCDialect,

View File

@ -6,6 +6,8 @@ add_subdirectory(AsyncToLLVM)
add_subdirectory(BufferizationToMemRef)
add_subdirectory(ComplexToLLVM)
add_subdirectory(ComplexToStandard)
add_subdirectory(ControlFlowToLLVM)
add_subdirectory(ControlFlowToSPIRV)
add_subdirectory(GPUCommon)
add_subdirectory(GPUToNVVM)
add_subdirectory(GPUToROCDL)
@ -25,10 +27,10 @@ add_subdirectory(OpenACCToSCF)
add_subdirectory(OpenMPToLLVM)
add_subdirectory(PDLToPDLInterp)
add_subdirectory(ReconcileUnrealizedCasts)
add_subdirectory(SCFToControlFlow)
add_subdirectory(SCFToGPU)
add_subdirectory(SCFToOpenMP)
add_subdirectory(SCFToSPIRV)
add_subdirectory(SCFToStandard)
add_subdirectory(ShapeToStandard)
add_subdirectory(SPIRVToLLVM)
add_subdirectory(StandardToLLVM)

View File

@ -0,0 +1,21 @@
add_mlir_conversion_library(MLIRControlFlowToLLVM
ControlFlowToLLVM.cpp
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/ControlFlowToLLVM
DEPENDS
MLIRConversionPassIncGen
intrinsics_gen
LINK_COMPONENTS
Core
LINK_LIBS PUBLIC
MLIRAnalysis
MLIRControlFlow
MLIRLLVMCommonConversion
MLIRLLVMIR
MLIRPass
MLIRTransformUtils
)

View File

@ -0,0 +1,148 @@
//===- ControlFlowToLLVM.cpp - ControlFlow to LLVM dialect conversion -----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a pass to convert the MLIR ControlFlow dialect
// into the LLVM IR dialect.
//
//===----------------------------------------------------------------------===//
#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
#include "../PassDetail.h"
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Conversion/LLVMCommon/VectorPattern.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/DialectConversion.h"
#include <functional>
using namespace mlir;
#define PASS_NAME "convert-cf-to-llvm"
namespace {
/// Lower `cf.assert`. The default lowering calls the `abort` function if the
/// assertion is violated and has no effect otherwise. The failure message is
/// ignored by the default lowering but should be propagated by any custom
/// lowering.
struct AssertOpLowering : public ConvertOpToLLVMPattern<cf::AssertOp> {
  using ConvertOpToLLVMPattern<cf::AssertOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(cf::AssertOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op.getLoc();

    // Insert the `abort` declaration if necessary. The declaration is created
    // at the start of the module so it is only emitted once.
    auto module = op->getParentOfType<ModuleOp>();
    auto abortFunc = module.lookupSymbol<LLVM::LLVMFuncOp>("abort");
    if (!abortFunc) {
      OpBuilder::InsertionGuard guard(rewriter);
      rewriter.setInsertionPointToStart(module.getBody());
      auto abortFuncTy = LLVM::LLVMFunctionType::get(getVoidType(), {});
      abortFunc = rewriter.create<LLVM::LLVMFuncOp>(rewriter.getUnknownLoc(),
                                                    "abort", abortFuncTy);
    }

    // Split block at `assert` operation; everything after the assert becomes
    // the continuation (success) block.
    Block *opBlock = rewriter.getInsertionBlock();
    auto opPosition = rewriter.getInsertionPoint();
    Block *continuationBlock = rewriter.splitBlock(opBlock, opPosition);

    // Generate IR to call `abort`. Note: `createBlock` also moves the
    // insertion point into the new block, so the call and unreachable land in
    // the failure block.
    Block *failureBlock = rewriter.createBlock(opBlock->getParent());
    rewriter.create<LLVM::CallOp>(loc, abortFunc, llvm::None);
    rewriter.create<LLVM::UnreachableOp>(loc);

    // Generate assertion test: branch to the continuation on success, to the
    // abort block on failure, replacing the original assert.
    rewriter.setInsertionPointToEnd(opBlock);
    rewriter.replaceOpWithNewOp<LLVM::CondBrOp>(
        op, adaptor.getArg(), continuationBlock, failureBlock);

    return success();
  }
};
/// Utility base class that rewrites a terminator carrying successors into its
/// one-to-one LLVM dialect counterpart: remapped operands, successors, and
/// attributes are forwarded unchanged.
template <typename SourceOp, typename TargetOp>
struct OneToOneLLVMTerminatorLowering
    : public ConvertOpToLLVMPattern<SourceOp> {
  using Base = OneToOneLLVMTerminatorLowering<SourceOp, TargetOp>;
  using ConvertOpToLLVMPattern<SourceOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // The target op reuses the exact successor list and attribute dictionary;
    // only the operands are replaced by their type-converted equivalents.
    rewriter.replaceOpWithNewOp<TargetOp>(op, adaptor.getOperands(),
                                          op->getSuccessors(), op->getAttrs());
    return success();
  }
};
// FIXME: this should be tablegen'ed as well.
/// Lowers `cf.br` to `llvm.br`.
struct BranchOpLowering
: public OneToOneLLVMTerminatorLowering<cf::BranchOp, LLVM::BrOp> {
using Base::Base;
};
/// Lowers `cf.cond_br` to `llvm.cond_br`.
struct CondBranchOpLowering
: public OneToOneLLVMTerminatorLowering<cf::CondBranchOp, LLVM::CondBrOp> {
using Base::Base;
};
/// Lowers `cf.switch` to `llvm.switch`.
struct SwitchOpLowering
: public OneToOneLLVMTerminatorLowering<cf::SwitchOp, LLVM::SwitchOp> {
using Base::Base;
};
} // namespace
/// Collect the conversion patterns that lower ControlFlow dialect operations
/// (cf.assert, cf.br, cf.cond_br, cf.switch) to the LLVM dialect, using
/// `converter` for type conversion.
void mlir::cf::populateControlFlowToLLVMConversionPatterns(
    LLVMTypeConverter &converter, RewritePatternSet &patterns) {
  patterns.add<AssertOpLowering, BranchOpLowering, CondBranchOpLowering,
               SwitchOpLowering>(converter);
}
//===----------------------------------------------------------------------===//
// Pass Definition
//===----------------------------------------------------------------------===//
namespace {
/// A pass converting ControlFlow dialect operations into the LLVM IR dialect.
struct ConvertControlFlowToLLVM
    : public ConvertControlFlowToLLVMBase<ConvertControlFlowToLLVM> {
  ConvertControlFlowToLLVM() = default;

  /// Run the dialect converter on the module.
  void runOnOperation() override {
    MLIRContext *ctx = &getContext();

    // Configure lowering options, honoring an explicitly requested index
    // bitwidth when one was supplied on the command line.
    LowerToLLVMOptions options(ctx);
    if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout)
      options.overrideIndexBitwidth(indexBitwidth);
    LLVMTypeConverter converter(ctx, options);

    RewritePatternSet patterns(ctx);
    mlir::cf::populateControlFlowToLLVMConversionPatterns(converter, patterns);

    // Partial conversion: any op not matched by the patterns is left as-is.
    LLVMConversionTarget target(*ctx);
    if (failed(applyPartialConversion(getOperation(), target,
                                      std::move(patterns))))
      signalPassFailure();
  }
};
} // namespace
/// Factory creating the ControlFlow-to-LLVM lowering pass.
std::unique_ptr<Pass> mlir::cf::createConvertControlFlowToLLVMPass() {
return std::make_unique<ConvertControlFlowToLLVM>();
}

View File

@ -0,0 +1,19 @@
# Conversion library lowering the ControlFlow dialect (cf.br, cf.cond_br) to
# the SPIR-V dialect.
add_mlir_conversion_library(MLIRControlFlowToSPIRV
ControlFlowToSPIRV.cpp
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/SPIRV
${MLIR_MAIN_INCLUDE_DIR}/mlir/IR
DEPENDS
MLIRConversionPassIncGen
LINK_LIBS PUBLIC
MLIRIR
MLIRControlFlow
MLIRPass
MLIRSPIRV
MLIRSPIRVConversion
MLIRSupport
MLIRTransformUtils
)

View File

@ -0,0 +1,73 @@
//===- ControlFlowToSPIRV.cpp - ControlFlow to SPIR-V Patterns ------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements patterns to convert the ControlFlow dialect to the
// SPIR-V dialect.
//
//===----------------------------------------------------------------------===//
#include "mlir/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRV.h"
#include "../SPIRVCommon/Pattern.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
#include "mlir/Dialect/SPIRV/Utils/LayoutUtils.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/Support/LogicalResult.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/Support/Debug.h"
#define DEBUG_TYPE "cf-to-spirv-pattern"
using namespace mlir;
//===----------------------------------------------------------------------===//
// Operation conversion
//===----------------------------------------------------------------------===//
namespace {
/// Converts cf.br to spv.Branch.
struct BranchOpPattern final : public OpConversionPattern<cf::BranchOp> {
  using OpConversionPattern<cf::BranchOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(cf::BranchOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // An unconditional branch maps directly onto spv.Branch; the adaptor
    // supplies the already type-converted block arguments.
    Block *dest = op.getDest();
    rewriter.replaceOpWithNewOp<spirv::BranchOp>(op, dest,
                                                 adaptor.getDestOperands());
    return success();
  }
};
/// Converts cf.cond_br to spv.BranchConditional.
/// Converts cf.cond_br to spv.BranchConditional.
struct CondBranchOpPattern final
    : public OpConversionPattern<cf::CondBranchOp> {
  using OpConversionPattern<cf::CondBranchOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(cf::CondBranchOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Destination blocks are taken from the op itself; the successor operands
    // come from the adaptor so they are already type-converted.
    Block *trueDest = op.getTrueDest();
    Block *falseDest = op.getFalseDest();
    rewriter.replaceOpWithNewOp<spirv::BranchConditionalOp>(
        op, op.getCondition(), trueDest, adaptor.getTrueDestOperands(),
        falseDest, adaptor.getFalseDestOperands());
    return success();
  }
};
} // namespace
//===----------------------------------------------------------------------===//
// Pattern population
//===----------------------------------------------------------------------===//
/// Appends the ControlFlow-to-SPIR-V conversion patterns (cf.br, cf.cond_br)
/// to `patterns`, using `typeConverter` for operand type conversion.
void mlir::cf::populateControlFlowToSPIRVPatterns(
    SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns) {
  patterns.add<BranchOpPattern, CondBranchOpPattern>(typeConverter,
                                                     patterns.getContext());
}

View File

@ -14,6 +14,7 @@
#include "mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h"
#include "mlir/Conversion/ArithmeticToLLVM/ArithmeticToLLVM.h"
#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/LoweringOptions.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
@ -172,8 +173,8 @@ struct LowerGpuOpsToNVVMOpsPass
populateGpuRewritePatterns(patterns);
(void)applyPatternsAndFoldGreedily(m, std::move(patterns));
mlir::arith::populateArithmeticToLLVMConversionPatterns(converter,
llvmPatterns);
arith::populateArithmeticToLLVMConversionPatterns(converter, llvmPatterns);
cf::populateControlFlowToLLVMConversionPatterns(converter, llvmPatterns);
populateStdToLLVMConversionPatterns(converter, llvmPatterns);
populateMemRefToLLVMConversionPatterns(converter, llvmPatterns);
populateGpuToNVVMConversionPatterns(converter, llvmPatterns);

View File

@ -19,7 +19,7 @@ add_mlir_conversion_library(MLIRLinalgToLLVM
MLIRLLVMCommonConversion
MLIRLLVMIR
MLIRMemRefToLLVM
MLIRSCFToStandard
MLIRSCFToControlFlow
MLIRTransforms
MLIRVectorToLLVM
MLIRVectorToSCF

View File

@ -14,7 +14,7 @@
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
#include "mlir/Conversion/SCFToStandard/SCFToStandard.h"
#include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h"
#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
#include "mlir/Conversion/VectorToSCF/VectorToSCF.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"

View File

@ -433,7 +433,7 @@ struct LoadStoreOpLowering : public ConvertOpToLLVMPattern<Derived> {
/// +---------------------------------+
/// | <code before the AtomicRMWOp> |
/// | <compute initial %loaded> |
/// | br loop(%loaded) |
/// | cf.br loop(%loaded) |
/// +---------------------------------+
/// |
/// -------| |
@ -444,7 +444,7 @@ struct LoadStoreOpLowering : public ConvertOpToLLVMPattern<Derived> {
/// | | %pair = cmpxchg |
/// | | %ok = %pair[0] |
/// | | %new = %pair[1] |
/// | | cond_br %ok, end, loop(%new) |
/// | | cf.cond_br %ok, end, loop(%new) |
/// | +--------------------------------+
/// | | |
/// |----------- |

View File

@ -10,6 +10,7 @@
#include "../PassDetail.h"
#include "mlir/Conversion/ArithmeticToLLVM/ArithmeticToLLVM.h"
#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
@ -66,7 +67,8 @@ void ConvertOpenMPToLLVMPass::runOnOperation() {
// Convert to OpenMP operations with LLVM IR dialect
RewritePatternSet patterns(&getContext());
LLVMTypeConverter converter(&getContext());
mlir::arith::populateArithmeticToLLVMConversionPatterns(converter, patterns);
arith::populateArithmeticToLLVMConversionPatterns(converter, patterns);
cf::populateControlFlowToLLVMConversionPatterns(converter, patterns);
populateMemRefToLLVMConversionPatterns(converter, patterns);
populateStdToLLVMConversionPatterns(converter, patterns);
populateOpenMPToLLVMConversionPatterns(converter, patterns);

View File

@ -29,6 +29,10 @@ namespace arith {
class ArithmeticDialect;
} // namespace arith
namespace cf {
class ControlFlowDialect;
} // namespace cf
namespace complex {
class ComplexDialect;
} // namespace complex

View File

@ -1,8 +1,8 @@
add_mlir_conversion_library(MLIRSCFToStandard
SCFToStandard.cpp
add_mlir_conversion_library(MLIRSCFToControlFlow
SCFToControlFlow.cpp
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/SCFToStandard
${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/SCFToControlFlow
DEPENDS
MLIRConversionPassIncGen
@ -12,6 +12,7 @@ add_mlir_conversion_library(MLIRSCFToStandard
LINK_LIBS PUBLIC
MLIRArithmetic
MLIRControlFlow
MLIRSCF
MLIRTransforms
)

View File

@ -1,4 +1,4 @@
//===- SCFToStandard.cpp - ControlFlow to CFG conversion ------------------===//
//===- SCFToControlFlow.cpp - SCF to CF conversion ------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
@ -11,11 +11,11 @@
//
//===----------------------------------------------------------------------===//
#include "mlir/Conversion/SCFToStandard/SCFToStandard.h"
#include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h"
#include "../PassDetail.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
@ -29,7 +29,8 @@ using namespace mlir::scf;
namespace {
struct SCFToStandardPass : public SCFToStandardBase<SCFToStandardPass> {
struct SCFToControlFlowPass
: public SCFToControlFlowBase<SCFToControlFlowPass> {
void runOnOperation() override;
};
@ -57,7 +58,7 @@ struct SCFToStandardPass : public SCFToStandardBase<SCFToStandardPass> {
// | <code before the ForOp> |
// | <definitions of %init...> |
// | <compute initial %iv value> |
// | br cond(%iv, %init...) |
// | cf.br cond(%iv, %init...) |
// +---------------------------------+
// |
// -------| |
@ -65,7 +66,7 @@ struct SCFToStandardPass : public SCFToStandardBase<SCFToStandardPass> {
// | +--------------------------------+
// | | cond(%iv, %init...): |
// | | <compare %iv to upper bound> |
// | | cond_br %r, body, end |
// | | cf.cond_br %r, body, end |
// | +--------------------------------+
// | | |
// | | -------------|
@ -83,7 +84,7 @@ struct SCFToStandardPass : public SCFToStandardBase<SCFToStandardPass> {
// | | <body contents> | |
// | | <operands of yield = %yields>| |
// | | %new_iv =<add step to %iv> | |
// | | br cond(%new_iv, %yields) | |
// | | cf.br cond(%new_iv, %yields) | |
// | +--------------------------------+ |
// | | |
// |----------- |--------------------
@ -125,7 +126,7 @@ struct ForLowering : public OpRewritePattern<ForOp> {
//
// +--------------------------------+
// | <code before the IfOp> |
// | cond_br %cond, %then, %else |
// | cf.cond_br %cond, %then, %else |
// +--------------------------------+
// | |
// | --------------|
@ -133,7 +134,7 @@ struct ForLowering : public OpRewritePattern<ForOp> {
// +--------------------------------+ |
// | then: | |
// | <then contents> | |
// | br continue | |
// | cf.br continue | |
// +--------------------------------+ |
// | |
// |---------- |-------------
@ -141,7 +142,7 @@ struct ForLowering : public OpRewritePattern<ForOp> {
// | +--------------------------------+
// | | else: |
// | | <else contents> |
// | | br continue |
// | | cf.br continue |
// | +--------------------------------+
// | |
// ------| |
@ -155,7 +156,7 @@ struct ForLowering : public OpRewritePattern<ForOp> {
//
// +--------------------------------+
// | <code before the IfOp> |
// | cond_br %cond, %then, %else |
// | cf.cond_br %cond, %then, %else |
// +--------------------------------+
// | |
// | --------------|
@ -163,7 +164,7 @@ struct ForLowering : public OpRewritePattern<ForOp> {
// +--------------------------------+ |
// | then: | |
// | <then contents> | |
// | br dom(%args...) | |
// | cf.br dom(%args...) | |
// +--------------------------------+ |
// | |
// |---------- |-------------
@ -171,14 +172,14 @@ struct ForLowering : public OpRewritePattern<ForOp> {
// | +--------------------------------+
// | | else: |
// | | <else contents> |
// | | br dom(%args...) |
// | | cf.br dom(%args...) |
// | +--------------------------------+
// | |
// ------| |
// v v
// +--------------------------------+
// | dom(%args...): |
// | br continue |
// | cf.br continue |
// +--------------------------------+
// |
// v
@ -218,7 +219,7 @@ struct ParallelLowering : public OpRewritePattern<mlir::scf::ParallelOp> {
///
/// +---------------------------------+
/// | <code before the WhileOp> |
/// | br ^before(%operands...) |
/// | cf.br ^before(%operands...) |
/// +---------------------------------+
/// |
/// -------| |
@ -233,7 +234,7 @@ struct ParallelLowering : public OpRewritePattern<mlir::scf::ParallelOp> {
/// | +--------------------------------+
/// | | ^before-last:
/// | | %cond = <compute condition> |
/// | | cond_br %cond, |
/// | | cf.cond_br %cond, |
/// | | ^after(%vals...), ^cont |
/// | +--------------------------------+
/// | | |
@ -249,7 +250,7 @@ struct ParallelLowering : public OpRewritePattern<mlir::scf::ParallelOp> {
/// | +--------------------------------+ |
/// | | ^after-last: | |
/// | | %yields... = <some payload> | |
/// | | br ^before(%yields...) | |
/// | | cf.br ^before(%yields...) | |
/// | +--------------------------------+ |
/// | | |
/// |----------- |--------------------
@ -321,7 +322,7 @@ LogicalResult ForLowering::matchAndRewrite(ForOp forOp,
SmallVector<Value, 8> loopCarried;
loopCarried.push_back(stepped);
loopCarried.append(terminator->operand_begin(), terminator->operand_end());
rewriter.create<BranchOp>(loc, conditionBlock, loopCarried);
rewriter.create<cf::BranchOp>(loc, conditionBlock, loopCarried);
rewriter.eraseOp(terminator);
// Compute loop bounds before branching to the condition.
@ -337,15 +338,16 @@ LogicalResult ForLowering::matchAndRewrite(ForOp forOp,
destOperands.push_back(lowerBound);
auto iterOperands = forOp.getIterOperands();
destOperands.append(iterOperands.begin(), iterOperands.end());
rewriter.create<BranchOp>(loc, conditionBlock, destOperands);
rewriter.create<cf::BranchOp>(loc, conditionBlock, destOperands);
// With the body block done, we can fill in the condition block.
rewriter.setInsertionPointToEnd(conditionBlock);
auto comparison = rewriter.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::slt, iv, upperBound);
rewriter.create<CondBranchOp>(loc, comparison, firstBodyBlock,
ArrayRef<Value>(), endBlock, ArrayRef<Value>());
rewriter.create<cf::CondBranchOp>(loc, comparison, firstBodyBlock,
ArrayRef<Value>(), endBlock,
ArrayRef<Value>());
// The result of the loop operation is the values of the condition block
// arguments except the induction variable on the last iteration.
rewriter.replaceOp(forOp, conditionBlock->getArguments().drop_front());
@ -369,7 +371,7 @@ LogicalResult IfLowering::matchAndRewrite(IfOp ifOp,
continueBlock =
rewriter.createBlock(remainingOpsBlock, ifOp.getResultTypes(),
SmallVector<Location>(ifOp.getNumResults(), loc));
rewriter.create<BranchOp>(loc, remainingOpsBlock);
rewriter.create<cf::BranchOp>(loc, remainingOpsBlock);
}
// Move blocks from the "then" region to the region containing 'scf.if',
@ -379,7 +381,7 @@ LogicalResult IfLowering::matchAndRewrite(IfOp ifOp,
Operation *thenTerminator = thenRegion.back().getTerminator();
ValueRange thenTerminatorOperands = thenTerminator->getOperands();
rewriter.setInsertionPointToEnd(&thenRegion.back());
rewriter.create<BranchOp>(loc, continueBlock, thenTerminatorOperands);
rewriter.create<cf::BranchOp>(loc, continueBlock, thenTerminatorOperands);
rewriter.eraseOp(thenTerminator);
rewriter.inlineRegionBefore(thenRegion, continueBlock);
@ -393,15 +395,15 @@ LogicalResult IfLowering::matchAndRewrite(IfOp ifOp,
Operation *elseTerminator = elseRegion.back().getTerminator();
ValueRange elseTerminatorOperands = elseTerminator->getOperands();
rewriter.setInsertionPointToEnd(&elseRegion.back());
rewriter.create<BranchOp>(loc, continueBlock, elseTerminatorOperands);
rewriter.create<cf::BranchOp>(loc, continueBlock, elseTerminatorOperands);
rewriter.eraseOp(elseTerminator);
rewriter.inlineRegionBefore(elseRegion, continueBlock);
}
rewriter.setInsertionPointToEnd(condBlock);
rewriter.create<CondBranchOp>(loc, ifOp.getCondition(), thenBlock,
/*trueArgs=*/ArrayRef<Value>(), elseBlock,
/*falseArgs=*/ArrayRef<Value>());
rewriter.create<cf::CondBranchOp>(loc, ifOp.getCondition(), thenBlock,
/*trueArgs=*/ArrayRef<Value>(), elseBlock,
/*falseArgs=*/ArrayRef<Value>());
// Ok, we're done!
rewriter.replaceOp(ifOp, continueBlock->getArguments());
@ -419,13 +421,13 @@ ExecuteRegionLowering::matchAndRewrite(ExecuteRegionOp op,
auto &region = op.getRegion();
rewriter.setInsertionPointToEnd(condBlock);
rewriter.create<BranchOp>(loc, &region.front());
rewriter.create<cf::BranchOp>(loc, &region.front());
for (Block &block : region) {
if (auto terminator = dyn_cast<scf::YieldOp>(block.getTerminator())) {
ValueRange terminatorOperands = terminator->getOperands();
rewriter.setInsertionPointToEnd(&block);
rewriter.create<BranchOp>(loc, remainingOpsBlock, terminatorOperands);
rewriter.create<cf::BranchOp>(loc, remainingOpsBlock, terminatorOperands);
rewriter.eraseOp(terminator);
}
}
@ -538,20 +540,21 @@ LogicalResult WhileLowering::matchAndRewrite(WhileOp whileOp,
// Branch to the "before" region.
rewriter.setInsertionPointToEnd(currentBlock);
rewriter.create<BranchOp>(loc, before, whileOp.getInits());
rewriter.create<cf::BranchOp>(loc, before, whileOp.getInits());
// Replace terminators with branches. Assuming bodies are SESE, which holds
// given only the patterns from this file, we only need to look at the last
// block. This should be reconsidered if we allow break/continue in SCF.
rewriter.setInsertionPointToEnd(beforeLast);
auto condOp = cast<ConditionOp>(beforeLast->getTerminator());
rewriter.replaceOpWithNewOp<CondBranchOp>(condOp, condOp.getCondition(),
after, condOp.getArgs(),
continuation, ValueRange());
rewriter.replaceOpWithNewOp<cf::CondBranchOp>(condOp, condOp.getCondition(),
after, condOp.getArgs(),
continuation, ValueRange());
rewriter.setInsertionPointToEnd(afterLast);
auto yieldOp = cast<scf::YieldOp>(afterLast->getTerminator());
rewriter.replaceOpWithNewOp<BranchOp>(yieldOp, before, yieldOp.getResults());
rewriter.replaceOpWithNewOp<cf::BranchOp>(yieldOp, before,
yieldOp.getResults());
// Replace the op with values "yielded" from the "before" region, which are
// visible by dominance.
@ -593,14 +596,14 @@ DoWhileLowering::matchAndRewrite(WhileOp whileOp,
// Branch to the "before" region.
rewriter.setInsertionPointToEnd(currentBlock);
rewriter.create<BranchOp>(whileOp.getLoc(), before, whileOp.getInits());
rewriter.create<cf::BranchOp>(whileOp.getLoc(), before, whileOp.getInits());
// Loop around the "before" region based on condition.
rewriter.setInsertionPointToEnd(beforeLast);
auto condOp = cast<ConditionOp>(beforeLast->getTerminator());
rewriter.replaceOpWithNewOp<CondBranchOp>(condOp, condOp.getCondition(),
before, condOp.getArgs(),
continuation, ValueRange());
rewriter.replaceOpWithNewOp<cf::CondBranchOp>(condOp, condOp.getCondition(),
before, condOp.getArgs(),
continuation, ValueRange());
// Replace the op with values "yielded" from the "before" region, which are
// visible by dominance.
@ -609,17 +612,18 @@ DoWhileLowering::matchAndRewrite(WhileOp whileOp,
return success();
}
void mlir::populateLoopToStdConversionPatterns(RewritePatternSet &patterns) {
void mlir::populateSCFToControlFlowConversionPatterns(
RewritePatternSet &patterns) {
patterns.add<ForLowering, IfLowering, ParallelLowering, WhileLowering,
ExecuteRegionLowering>(patterns.getContext());
patterns.add<DoWhileLowering>(patterns.getContext(), /*benefit=*/2);
}
void SCFToStandardPass::runOnOperation() {
void SCFToControlFlowPass::runOnOperation() {
RewritePatternSet patterns(&getContext());
populateLoopToStdConversionPatterns(patterns);
// Configure conversion to lower out scf.for, scf.if, scf.parallel and
// scf.while. Anything else is fine.
populateSCFToControlFlowConversionPatterns(patterns);
// Configure conversion to lower out SCF operations.
ConversionTarget target(getContext());
target.addIllegalOp<scf::ForOp, scf::IfOp, scf::ParallelOp, scf::WhileOp,
scf::ExecuteRegionOp>();
@ -629,6 +633,6 @@ void SCFToStandardPass::runOnOperation() {
signalPassFailure();
}
std::unique_ptr<Pass> mlir::createLowerToCFGPass() {
return std::make_unique<SCFToStandardPass>();
std::unique_ptr<Pass> mlir::createConvertSCFToCFPass() {
return std::make_unique<SCFToControlFlowPass>();
}

View File

@ -9,6 +9,7 @@
#include "mlir/Conversion/ShapeToStandard/ShapeToStandard.h"
#include "../PassDetail.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
@ -29,7 +30,7 @@ public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(shape::CstrRequireOp op,
PatternRewriter &rewriter) const override {
rewriter.create<AssertOp>(op.getLoc(), op.getPred(), op.getMsgAttr());
rewriter.create<cf::AssertOp>(op.getLoc(), op.getPred(), op.getMsgAttr());
rewriter.replaceOpWithNewOp<shape::ConstWitnessOp>(op, true);
return success();
}

View File

@ -14,6 +14,7 @@ add_mlir_conversion_library(MLIRStandardToLLVM
LINK_LIBS PUBLIC
MLIRAnalysis
MLIRArithmeticToLLVM
MLIRControlFlowToLLVM
MLIRDataLayoutInterfaces
MLIRLLVMCommonConversion
MLIRLLVMIR

View File

@ -14,6 +14,7 @@
#include "../PassDetail.h"
#include "mlir/Analysis/DataLayoutAnalysis.h"
#include "mlir/Conversion/ArithmeticToLLVM/ArithmeticToLLVM.h"
#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Conversion/LLVMCommon/VectorPattern.h"
@ -387,48 +388,6 @@ struct BarePtrFuncOpConversion : public FuncOpConversionBase {
}
};
/// Lower `std.assert`. The default lowering calls the `abort` function if the
/// assertion is violated and has no effect otherwise. The failure message is
/// ignored by the default lowering but should be propagated by any custom
/// lowering.
struct AssertOpLowering : public ConvertOpToLLVMPattern<AssertOp> {
using ConvertOpToLLVMPattern<AssertOp>::ConvertOpToLLVMPattern;
LogicalResult
matchAndRewrite(AssertOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto loc = op.getLoc();
// Insert the `abort` declaration if necessary.
auto module = op->getParentOfType<ModuleOp>();
auto abortFunc = module.lookupSymbol<LLVM::LLVMFuncOp>("abort");
if (!abortFunc) {
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPointToStart(module.getBody());
auto abortFuncTy = LLVM::LLVMFunctionType::get(getVoidType(), {});
abortFunc = rewriter.create<LLVM::LLVMFuncOp>(rewriter.getUnknownLoc(),
"abort", abortFuncTy);
}
// Split block at `assert` operation.
Block *opBlock = rewriter.getInsertionBlock();
auto opPosition = rewriter.getInsertionPoint();
Block *continuationBlock = rewriter.splitBlock(opBlock, opPosition);
// Generate IR to call `abort`.
Block *failureBlock = rewriter.createBlock(opBlock->getParent());
rewriter.create<LLVM::CallOp>(loc, abortFunc, llvm::None);
rewriter.create<LLVM::UnreachableOp>(loc);
// Generate assertion test.
rewriter.setInsertionPointToEnd(opBlock);
rewriter.replaceOpWithNewOp<LLVM::CondBrOp>(
op, adaptor.getArg(), continuationBlock, failureBlock);
return success();
}
};
struct ConstantOpLowering : public ConvertOpToLLVMPattern<ConstantOp> {
using ConvertOpToLLVMPattern<ConstantOp>::ConvertOpToLLVMPattern;
@ -550,22 +509,6 @@ struct UnrealizedConversionCastOpLowering
}
};
// Base class for LLVM IR lowering terminator operations with successors.
template <typename SourceOp, typename TargetOp>
struct OneToOneLLVMTerminatorLowering
: public ConvertOpToLLVMPattern<SourceOp> {
using ConvertOpToLLVMPattern<SourceOp>::ConvertOpToLLVMPattern;
using Super = OneToOneLLVMTerminatorLowering<SourceOp, TargetOp>;
LogicalResult
matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
rewriter.replaceOpWithNewOp<TargetOp>(op, adaptor.getOperands(),
op->getSuccessors(), op->getAttrs());
return success();
}
};
// Special lowering pattern for `ReturnOps`. Unlike all other operations,
// `ReturnOp` interacts with the function signature and must have as many
// operands as the function has return values. Because in LLVM IR, functions
@ -633,21 +576,6 @@ struct ReturnOpLowering : public ConvertOpToLLVMPattern<ReturnOp> {
return success();
}
};
// FIXME: this should be tablegen'ed as well.
struct BranchOpLowering
: public OneToOneLLVMTerminatorLowering<BranchOp, LLVM::BrOp> {
using Super::Super;
};
struct CondBranchOpLowering
: public OneToOneLLVMTerminatorLowering<CondBranchOp, LLVM::CondBrOp> {
using Super::Super;
};
struct SwitchOpLowering
: public OneToOneLLVMTerminatorLowering<SwitchOp, LLVM::SwitchOp> {
using Super::Super;
};
} // namespace
void mlir::populateStdToLLVMFuncOpConversionPattern(
@ -663,14 +591,10 @@ void mlir::populateStdToLLVMConversionPatterns(LLVMTypeConverter &converter,
populateStdToLLVMFuncOpConversionPattern(converter, patterns);
// clang-format off
patterns.add<
AssertOpLowering,
BranchOpLowering,
CallIndirectOpLowering,
CallOpLowering,
CondBranchOpLowering,
ConstantOpLowering,
ReturnOpLowering,
SwitchOpLowering>(converter);
ReturnOpLowering>(converter);
// clang-format on
}
@ -721,6 +645,7 @@ struct LLVMLoweringPass : public ConvertStandardToLLVMBase<LLVMLoweringPass> {
RewritePatternSet patterns(&getContext());
populateStdToLLVMConversionPatterns(typeConverter, patterns);
arith::populateArithmeticToLLVMConversionPatterns(typeConverter, patterns);
cf::populateControlFlowToLLVMConversionPatterns(typeConverter, patterns);
LLVMConversionTarget target(getContext());
if (failed(applyPartialConversion(m, target, std::move(patterns))))

View File

@ -11,6 +11,7 @@ add_mlir_conversion_library(MLIRStandardToSPIRV
LINK_LIBS PUBLIC
MLIRArithmeticToSPIRV
MLIRControlFlowToSPIRV
MLIRIR
MLIRMathToSPIRV
MLIRMemRef

View File

@ -46,24 +46,6 @@ public:
ConversionPatternRewriter &rewriter) const override;
};
/// Converts std.br to spv.Branch.
struct BranchOpPattern final : public OpConversionPattern<BranchOp> {
using OpConversionPattern<BranchOp>::OpConversionPattern;
LogicalResult
matchAndRewrite(BranchOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override;
};
/// Converts std.cond_br to spv.BranchConditional.
struct CondBranchOpPattern final : public OpConversionPattern<CondBranchOp> {
using OpConversionPattern<CondBranchOp>::OpConversionPattern;
LogicalResult
matchAndRewrite(CondBranchOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override;
};
/// Converts tensor.extract into loading using access chains from SPIR-V local
/// variables.
class TensorExtractPattern final
@ -146,31 +128,6 @@ ReturnOpPattern::matchAndRewrite(ReturnOp returnOp, OpAdaptor adaptor,
return success();
}
//===----------------------------------------------------------------------===//
// BranchOpPattern
//===----------------------------------------------------------------------===//
LogicalResult
BranchOpPattern::matchAndRewrite(BranchOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
rewriter.replaceOpWithNewOp<spirv::BranchOp>(op, op.getDest(),
adaptor.getDestOperands());
return success();
}
//===----------------------------------------------------------------------===//
// CondBranchOpPattern
//===----------------------------------------------------------------------===//
LogicalResult CondBranchOpPattern::matchAndRewrite(
CondBranchOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
rewriter.replaceOpWithNewOp<spirv::BranchConditionalOp>(
op, op.getCondition(), op.getTrueDest(), adaptor.getTrueDestOperands(),
op.getFalseDest(), adaptor.getFalseDestOperands());
return success();
}
//===----------------------------------------------------------------------===//
// Pattern population
//===----------------------------------------------------------------------===//
@ -189,8 +146,7 @@ void populateStandardToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
spirv::ElementwiseOpPattern<arith::MinSIOp, spirv::GLSLSMinOp>,
spirv::ElementwiseOpPattern<arith::MinUIOp, spirv::GLSLUMinOp>,
ReturnOpPattern, BranchOpPattern, CondBranchOpPattern>(typeConverter,
context);
ReturnOpPattern>(typeConverter, context);
}
void populateTensorToSPIRVPatterns(SPIRVTypeConverter &typeConverter,

View File

@ -13,6 +13,7 @@
#include "mlir/Conversion/StandardToSPIRV/StandardToSPIRVPass.h"
#include "../PassDetail.h"
#include "mlir/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.h"
#include "mlir/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRV.h"
#include "mlir/Conversion/MathToSPIRV/MathToSPIRV.h"
#include "mlir/Conversion/StandardToSPIRV/StandardToSPIRV.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
@ -40,9 +41,11 @@ void ConvertStandardToSPIRVPass::runOnOperation() {
options.emulateNon32BitScalarTypes = this->emulateNon32BitScalarTypes;
SPIRVTypeConverter typeConverter(targetAttr, options);
// TODO ArithmeticToSPIRV cannot be applied separately to StandardToSPIRV
// TODO ArithmeticToSPIRV/ControlFlowToSPIRV cannot be applied separately to
// StandardToSPIRV
RewritePatternSet patterns(context);
arith::populateArithmeticToSPIRVPatterns(typeConverter, patterns);
cf::populateControlFlowToSPIRVPatterns(typeConverter, patterns);
populateMathToSPIRVPatterns(typeConverter, patterns);
populateStandardToSPIRVPatterns(typeConverter, patterns);
populateTensorToSPIRVPatterns(typeConverter, /*byteCountThreshold=*/64,

View File

@ -15,6 +15,7 @@
#include "mlir/Analysis/Liveness.h"
#include "mlir/Dialect/Async/IR/Async.h"
#include "mlir/Dialect/Async/Passes.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/PatternMatch.h"
@ -169,11 +170,11 @@ private:
///
/// ^entry:
/// %token = async.runtime.create : !async.token
/// cond_br %cond, ^bb1, ^bb2
/// cf.cond_br %cond, ^bb1, ^bb2
/// ^bb1:
/// async.runtime.await %token
/// async.runtime.drop_ref %token
/// br ^bb2
/// cf.br ^bb2
/// ^bb2:
/// return
///
@ -185,14 +186,14 @@ private:
///
/// ^entry:
/// %token = async.runtime.create : !async.token
/// cond_br %cond, ^bb1, ^reference_counting
/// cf.cond_br %cond, ^bb1, ^reference_counting
/// ^bb1:
/// async.runtime.await %token
/// async.runtime.drop_ref %token
/// br ^bb2
/// cf.br ^bb2
/// ^reference_counting:
/// async.runtime.drop_ref %token
/// br ^bb2
/// cf.br ^bb2
/// ^bb2:
/// return
///
@ -208,7 +209,7 @@ private:
/// async.coro.suspend %ret, ^suspend, ^resume, ^cleanup
/// ^resume:
/// %0 = async.runtime.load %value
/// br ^cleanup
/// cf.br ^cleanup
/// ^cleanup:
/// ...
/// ^suspend:
@ -406,7 +407,7 @@ AsyncRuntimeRefCountingPass::addDropRefInDivergentLivenessSuccessor(
refCountingBlock = &successor->getParent()->emplaceBlock();
refCountingBlock->moveBefore(successor);
OpBuilder builder = OpBuilder::atBlockEnd(refCountingBlock);
builder.create<BranchOp>(value.getLoc(), successor);
builder.create<cf::BranchOp>(value.getLoc(), successor);
}
OpBuilder builder = OpBuilder::atBlockBegin(refCountingBlock);

View File

@ -12,10 +12,11 @@
//===----------------------------------------------------------------------===//
#include "PassDetail.h"
#include "mlir/Conversion/SCFToStandard/SCFToStandard.h"
#include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Async/IR/Async.h"
#include "mlir/Dialect/Async/Passes.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/BlockAndValueMapping.h"
@ -105,18 +106,18 @@ struct CoroMachinery {
/// %value = <async value> : !async.value<T> // create async value
/// %id = async.coro.id // create a coroutine id
/// %hdl = async.coro.begin %id // create a coroutine handle
/// br ^preexisting_entry_block
/// cf.br ^preexisting_entry_block
///
/// /* preexisting blocks modified to branch to the cleanup block */
///
/// ^set_error: // this block created lazily only if needed (see code below)
/// async.runtime.set_error %token : !async.token
/// async.runtime.set_error %value : !async.value<T>
/// br ^cleanup
/// cf.br ^cleanup
///
/// ^cleanup:
/// async.coro.free %hdl // delete the coroutine state
/// br ^suspend
/// cf.br ^suspend
///
/// ^suspend:
/// async.coro.end %hdl // marks the end of a coroutine
@ -147,7 +148,7 @@ static CoroMachinery setupCoroMachinery(FuncOp func) {
auto coroIdOp = builder.create<CoroIdOp>(CoroIdType::get(ctx));
auto coroHdlOp =
builder.create<CoroBeginOp>(CoroHandleType::get(ctx), coroIdOp.id());
builder.create<BranchOp>(originalEntryBlock);
builder.create<cf::BranchOp>(originalEntryBlock);
Block *cleanupBlock = func.addBlock();
Block *suspendBlock = func.addBlock();
@ -159,7 +160,7 @@ static CoroMachinery setupCoroMachinery(FuncOp func) {
builder.create<CoroFreeOp>(coroIdOp.id(), coroHdlOp.handle());
// Branch into the suspend block.
builder.create<BranchOp>(suspendBlock);
builder.create<cf::BranchOp>(suspendBlock);
// ------------------------------------------------------------------------ //
// Coroutine suspend block: mark the end of a coroutine and return allocated
@ -186,7 +187,7 @@ static CoroMachinery setupCoroMachinery(FuncOp func) {
Operation *terminator = block.getTerminator();
if (auto yield = dyn_cast<YieldOp>(terminator)) {
builder.setInsertionPointToEnd(&block);
builder.create<BranchOp>(cleanupBlock);
builder.create<cf::BranchOp>(cleanupBlock);
}
}
@ -227,7 +228,7 @@ static Block *setupSetErrorBlock(CoroMachinery &coro) {
builder.create<RuntimeSetErrorOp>(retValue);
// Branch into the cleanup block.
builder.create<BranchOp>(coro.cleanup);
builder.create<cf::BranchOp>(coro.cleanup);
return coro.setError;
}
@ -305,7 +306,7 @@ outlineExecuteOp(SymbolTable &symbolTable, ExecuteOp execute) {
// Async resume operation (execution will be resumed in a thread managed by
// the async runtime).
{
BranchOp branch = cast<BranchOp>(coro.entry->getTerminator());
cf::BranchOp branch = cast<cf::BranchOp>(coro.entry->getTerminator());
builder.setInsertionPointToEnd(coro.entry);
// Save the coroutine state: async.coro.save
@ -419,8 +420,8 @@ public:
isError, builder.create<arith::ConstantOp>(
loc, i1, builder.getIntegerAttr(i1, 1)));
builder.create<AssertOp>(notError,
"Awaited async operand is in error state");
builder.create<cf::AssertOp>(notError,
"Awaited async operand is in error state");
}
// Inside the coroutine we convert await operation into coroutine suspension
@ -452,11 +453,11 @@ public:
// Check if the awaited value is in the error state.
builder.setInsertionPointToStart(resume);
auto isError = builder.create<RuntimeIsErrorOp>(loc, i1, operand);
builder.create<CondBranchOp>(isError,
/*trueDest=*/setupSetErrorBlock(coro),
/*trueArgs=*/ArrayRef<Value>(),
/*falseDest=*/continuation,
/*falseArgs=*/ArrayRef<Value>());
builder.create<cf::CondBranchOp>(isError,
/*trueDest=*/setupSetErrorBlock(coro),
/*trueArgs=*/ArrayRef<Value>(),
/*falseDest=*/continuation,
/*falseArgs=*/ArrayRef<Value>());
// Make sure that replacement value will be constructed in the
// continuation block.
@ -560,18 +561,18 @@ private:
};
//===----------------------------------------------------------------------===//
// Convert std.assert operation to cond_br into `set_error` block.
// Convert std.assert operation to cf.cond_br into `set_error` block.
//===----------------------------------------------------------------------===//
class AssertOpLowering : public OpConversionPattern<AssertOp> {
class AssertOpLowering : public OpConversionPattern<cf::AssertOp> {
public:
AssertOpLowering(MLIRContext *ctx,
llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions)
: OpConversionPattern<AssertOp>(ctx),
: OpConversionPattern<cf::AssertOp>(ctx),
outlinedFunctions(outlinedFunctions) {}
LogicalResult
matchAndRewrite(AssertOp op, OpAdaptor adaptor,
matchAndRewrite(cf::AssertOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// Check if assert operation is inside the async coroutine function.
auto func = op->template getParentOfType<FuncOp>();
@ -585,11 +586,11 @@ public:
Block *cont = rewriter.splitBlock(op->getBlock(), Block::iterator(op));
rewriter.setInsertionPointToEnd(cont->getPrevNode());
rewriter.create<CondBranchOp>(loc, adaptor.getArg(),
/*trueDest=*/cont,
/*trueArgs=*/ArrayRef<Value>(),
/*falseDest=*/setupSetErrorBlock(coro),
/*falseArgs=*/ArrayRef<Value>());
rewriter.create<cf::CondBranchOp>(loc, adaptor.getArg(),
/*trueDest=*/cont,
/*trueArgs=*/ArrayRef<Value>(),
/*falseDest=*/setupSetErrorBlock(coro),
/*falseArgs=*/ArrayRef<Value>());
rewriter.eraseOp(op);
return success();
@ -765,7 +766,7 @@ void AsyncToAsyncRuntimePass::runOnOperation() {
// and we have to make sure that structured control flow operations with async
// operations in nested regions will be converted to branch-based control flow
// before we add the coroutine basic blocks.
populateLoopToStdConversionPatterns(asyncPatterns);
populateSCFToControlFlowConversionPatterns(asyncPatterns);
// Async lowering does not use type converter because it must preserve all
// types for async.runtime operations.
@ -792,14 +793,15 @@ void AsyncToAsyncRuntimePass::runOnOperation() {
});
return !walkResult.wasInterrupted();
});
runtimeTarget.addLegalOp<AssertOp, arith::XOrIOp, arith::ConstantOp,
ConstantOp, BranchOp, CondBranchOp>();
runtimeTarget.addLegalOp<cf::AssertOp, arith::XOrIOp, arith::ConstantOp,
ConstantOp, cf::BranchOp, cf::CondBranchOp>();
// Assertions must be converted to runtime errors inside async functions.
runtimeTarget.addDynamicallyLegalOp<AssertOp>([&](AssertOp op) -> bool {
auto func = op->getParentOfType<FuncOp>();
return outlinedFunctions.find(func) == outlinedFunctions.end();
});
runtimeTarget.addDynamicallyLegalOp<cf::AssertOp>(
[&](cf::AssertOp op) -> bool {
auto func = op->getParentOfType<FuncOp>();
return outlinedFunctions.find(func) == outlinedFunctions.end();
});
if (eliminateBlockingAwaitOps)
runtimeTarget.addDynamicallyLegalOp<RuntimeAwaitOp>(

View File

@ -17,7 +17,7 @@ add_mlir_dialect_library(MLIRAsyncTransforms
MLIRIR
MLIRPass
MLIRSCF
MLIRSCFToStandard
MLIRSCFToControlFlow
MLIRStandard
MLIRTransforms
MLIRTransformUtils

View File

@ -18,12 +18,12 @@
// (using the BufferViewFlowAnalysis class). Consider the following example:
//
// ^bb0(%arg0):
// cond_br %cond, ^bb1, ^bb2
// cf.cond_br %cond, ^bb1, ^bb2
// ^bb1:
// br ^exit(%arg0)
// cf.br ^exit(%arg0)
// ^bb2:
// %new_value = ...
// br ^exit(%new_value)
// cf.br ^exit(%new_value)
// ^exit(%arg1):
// return %arg1;
//

View File

@ -6,6 +6,7 @@ add_subdirectory(Async)
add_subdirectory(AMX)
add_subdirectory(Bufferization)
add_subdirectory(Complex)
add_subdirectory(ControlFlow)
add_subdirectory(DLTI)
add_subdirectory(EmitC)
add_subdirectory(GPU)

View File

@ -0,0 +1 @@
add_subdirectory(IR)

View File

@ -0,0 +1,15 @@
add_mlir_dialect_library(MLIRControlFlow
ControlFlowOps.cpp
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/ControlFlow/IR
DEPENDS
MLIRControlFlowOpsIncGen
LINK_LIBS PUBLIC
MLIRArithmetic
MLIRControlFlowInterfaces
MLIRIR
MLIRSideEffectInterfaces
)

View File

@ -0,0 +1,891 @@
//===- ControlFlowOps.cpp - ControlFlow Operations ------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/CommonFolders.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/IR/Value.h"
#include "mlir/Support/MathExtras.h"
#include "mlir/Transforms/InliningUtils.h"
#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/StringSwitch.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/raw_ostream.h"
#include <numeric>
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOpsDialect.cpp.inc"
using namespace mlir;
using namespace mlir::cf;
//===----------------------------------------------------------------------===//
// ControlFlowDialect Interfaces
//===----------------------------------------------------------------------===//
namespace {
/// This class defines the interface for handling inlining with control flow
/// operations.
struct ControlFlowInlinerInterface : public DialectInlinerInterface {
  using DialectInlinerInterface::DialectInlinerInterface;
  ~ControlFlowInlinerInterface() override = default;

  /// All control flow operations can be inlined.
  bool isLegalToInline(Operation *call, Operation *callable,
                       bool wouldBeCloned) const final {
    return true;
  }
  /// Control flow operations may be inlined into any region.
  bool isLegalToInline(Operation *, Region *, bool,
                       BlockAndValueMapping &) const final {
    return true;
  }

  /// ControlFlow terminator operations don't really need any special handling.
  void handleTerminator(Operation *op, Block *newDest) const final {}
};
} // namespace
//===----------------------------------------------------------------------===//
// ControlFlowDialect
//===----------------------------------------------------------------------===//
void ControlFlowDialect::initialize() {
  // Register the operations generated from the ODS definitions
  // (AssertOp, BranchOp, CondBranchOp, SwitchOp).
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.cpp.inc"
      >();
  // Attach the inliner interface so cf ops participate in inlining.
  addInterfaces<ControlFlowInlinerInterface>();
}
//===----------------------------------------------------------------------===//
// AssertOp
//===----------------------------------------------------------------------===//
/// Canonicalization: an assertion whose condition is the constant `true` can
/// never fire, so it is simply removed.
LogicalResult AssertOp::canonicalize(AssertOp op, PatternRewriter &rewriter) {
  // Only fold when the argument statically matches the constant one.
  if (!matchPattern(op.getArg(), m_One()))
    return failure();
  rewriter.eraseOp(op);
  return success();
}
//===----------------------------------------------------------------------===//
// BranchOp
//===----------------------------------------------------------------------===//
/// Given a successor, try to collapse it to a new destination if it only
/// contains a passthrough unconditional branch. If the successor is
/// collapsable, `successor` and `successorOperands` are updated to reference
/// the new destination and values. `argStorage` is used as storage if operands
/// to the collapsed successor need to be remapped. It must outlive uses of
/// successorOperands.
///
/// Returns success iff `successor`/`successorOperands` were rewritten; on
/// failure both are left untouched.
static LogicalResult collapseBranch(Block *&successor,
                                    ValueRange &successorOperands,
                                    SmallVectorImpl<Value> &argStorage) {
  // Check that the successor only contains a unconditional branch.
  if (std::next(successor->begin()) != successor->end())
    return failure();
  // Check that the terminator is an unconditional branch.
  BranchOp successorBranch = dyn_cast<BranchOp>(successor->getTerminator());
  if (!successorBranch)
    return failure();
  // Check that the arguments are only used within the terminator. Any other
  // use would become invalid once the block is bypassed.
  for (BlockArgument arg : successor->getArguments()) {
    for (Operation *user : arg.getUsers())
      if (user != successorBranch)
        return failure();
  }
  // Don't try to collapse branches to infinite loops.
  Block *successorDest = successorBranch.getDest();
  if (successorDest == successor)
    return failure();

  // Update the operands to the successor. If the branch parent has no
  // arguments, we can use the branch operands directly.
  OperandRange operands = successorBranch.getOperands();
  if (successor->args_empty()) {
    successor = successorDest;
    successorOperands = operands;
    return success();
  }

  // Otherwise, we need to remap any argument operands: operands of the inner
  // branch that are block arguments of `successor` are substituted with the
  // value the caller was already passing for that argument.
  for (Value operand : operands) {
    BlockArgument argOperand = operand.dyn_cast<BlockArgument>();
    if (argOperand && argOperand.getOwner() == successor)
      argStorage.push_back(successorOperands[argOperand.getArgNumber()]);
    else
      argStorage.push_back(operand);
  }

  successor = successorDest;
  // NOTE: successorOperands now aliases argStorage, hence the lifetime
  // requirement documented above.
  successorOperands = argStorage;
  return success();
}
/// Simplify a branch to a block that has a single predecessor. This effectively
/// merges the two blocks.
static LogicalResult
simplifyBrToBlockWithSinglePred(BranchOp op, PatternRewriter &rewriter) {
  Block *dest = op.getDest();
  Block *parentBlock = op->getBlock();

  // A self-branch cannot be merged, and the destination must be reachable
  // only through this branch for the merge to be valid.
  if (dest == parentBlock)
    return failure();
  if (!llvm::hasSingleElement(dest->getPredecessors()))
    return failure();

  // Splice the destination into the current block, forwarding the branch
  // operands as the destination's block arguments, then drop the branch.
  rewriter.mergeBlocks(dest, parentBlock, op.getOperands());
  rewriter.eraseOp(op);
  return success();
}
///   br ^bb1
/// ^bb1
///   br ^bbN(...)
///
///  -> br ^bbN(...)
///
static LogicalResult simplifyPassThroughBr(BranchOp op,
                                           PatternRewriter &rewriter) {
  Block *successor = op.getDest();
  ValueRange successorOperands = op.getOperands();
  SmallVector<Value, 4> operandStorage;

  // A branch back into our own block is not a pass-through; also bail out
  // when the successor cannot be collapsed to a further destination.
  if (successor == op->getBlock())
    return failure();
  if (failed(collapseBranch(successor, successorOperands, operandStorage)))
    return failure();

  // Re-target the branch at the collapsed destination.
  rewriter.replaceOpWithNewOp<BranchOp>(op, successor, successorOperands);
  return success();
}
/// Canonicalize an unconditional branch: first try merging with a
/// single-predecessor successor, then try folding away a pass-through block.
LogicalResult BranchOp::canonicalize(BranchOp op, PatternRewriter &rewriter) {
  if (succeeded(simplifyBrToBlockWithSinglePred(op, rewriter)))
    return success();
  return simplifyPassThroughBr(op, rewriter);
}
/// Re-target this branch at the given block.
void BranchOp::setDest(Block *block) { return setSuccessor(block); }

/// Erase the forwarded operand at `index`.
void BranchOp::eraseOperand(unsigned index) { (*this)->eraseOperand(index); }

Optional<MutableOperandRange>
BranchOp::getMutableSuccessorOperands(unsigned index) {
  assert(index == 0 && "invalid successor index");
  // An unconditional branch has one successor; all operands forward to it.
  return getDestOperandsMutable();
}

/// The successor is unconditional, so it never depends on operand values.
Block *BranchOp::getSuccessorForOperands(ArrayRef<Attribute>) {
  return getDest();
}
//===----------------------------------------------------------------------===//
// CondBranchOp
//===----------------------------------------------------------------------===//
namespace {
/// cf.cond_br true, ^bb1, ^bb2
///  -> br ^bb1
/// cf.cond_br false, ^bb1, ^bb2
///  -> br ^bb2
///
struct SimplifyConstCondBranchPred : public OpRewritePattern<CondBranchOp> {
  using OpRewritePattern<CondBranchOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(CondBranchOp condbr,
                                PatternRewriter &rewriter) const override {
    // Resolve the statically-taken edge, if the condition is constant.
    Block *takenDest = nullptr;
    ValueRange takenOperands;
    if (matchPattern(condbr.getCondition(), m_NonZero())) {
      // True branch taken.
      takenDest = condbr.getTrueDest();
      takenOperands = condbr.getTrueOperands();
    } else if (matchPattern(condbr.getCondition(), m_Zero())) {
      // False branch taken.
      takenDest = condbr.getFalseDest();
      takenOperands = condbr.getFalseOperands();
    }
    if (!takenDest)
      return failure();

    // Replace the conditional branch with an unconditional one.
    rewriter.replaceOpWithNewOp<BranchOp>(condbr, takenDest, takenOperands);
    return success();
  }
};
///   cf.cond_br %cond, ^bb1, ^bb2
/// ^bb1
///   br ^bbN(...)
/// ^bb2
///   br ^bbK(...)
///
///  -> cf.cond_br %cond, ^bbN(...), ^bbK(...)
///
struct SimplifyPassThroughCondBranch : public OpRewritePattern<CondBranchOp> {
  using OpRewritePattern<CondBranchOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(CondBranchOp condbr,
                                PatternRewriter &rewriter) const override {
    Block *trueDest = condbr.getTrueDest(), *falseDest = condbr.getFalseDest();
    ValueRange trueDestOperands = condbr.getTrueOperands();
    ValueRange falseDestOperands = condbr.getFalseOperands();
    // Storage must stay alive until the replacement op is built: on success,
    // collapseBranch may leave the operand ranges pointing into it.
    SmallVector<Value, 4> trueDestOperandStorage, falseDestOperandStorage;

    // Try to collapse one of the current successors. Each call updates the
    // destination/operands in place when it succeeds.
    LogicalResult collapsedTrue =
        collapseBranch(trueDest, trueDestOperands, trueDestOperandStorage);
    LogicalResult collapsedFalse =
        collapseBranch(falseDest, falseDestOperands, falseDestOperandStorage);
    if (failed(collapsedTrue) && failed(collapsedFalse))
      return failure();

    // Create a new branch with the collapsed successors.
    rewriter.replaceOpWithNewOp<CondBranchOp>(condbr, condbr.getCondition(),
                                              trueDest, trueDestOperands,
                                              falseDest, falseDestOperands);
    return success();
  }
};
/// cf.cond_br %cond, ^bb1(A, ..., N), ^bb1(A, ..., N)
///  -> br ^bb1(A, ..., N)
///
/// cf.cond_br %cond, ^bb1(A), ^bb1(B)
///  -> %select = arith.select %cond, A, B
///     br ^bb1(%select)
///
struct SimplifyCondBranchIdenticalSuccessors
    : public OpRewritePattern<CondBranchOp> {
  using OpRewritePattern<CondBranchOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(CondBranchOp condbr,
                                PatternRewriter &rewriter) const override {
    // Check that the true and false destinations are the same and have the same
    // operands.
    Block *trueDest = condbr.getTrueDest();
    if (trueDest != condbr.getFalseDest())
      return failure();

    // If all of the operands match, no selects need to be generated.
    OperandRange trueOperands = condbr.getTrueOperands();
    OperandRange falseOperands = condbr.getFalseOperands();
    if (trueOperands == falseOperands) {
      rewriter.replaceOpWithNewOp<BranchOp>(condbr, trueDest, trueOperands);
      return success();
    }

    // Otherwise, if the current block is the only predecessor insert selects
    // for any mismatched branch operands. (With other predecessors the block
    // arguments may carry values from elsewhere, so selects would be wrong.)
    if (trueDest->getUniquePredecessor() != condbr->getBlock())
      return failure();

    // Generate a select for any operands that differ between the two.
    SmallVector<Value, 8> mergedOperands;
    mergedOperands.reserve(trueOperands.size());
    Value condition = condbr.getCondition();
    for (auto it : llvm::zip(trueOperands, falseOperands)) {
      if (std::get<0>(it) == std::get<1>(it))
        mergedOperands.push_back(std::get<0>(it));
      else
        mergedOperands.push_back(rewriter.create<arith::SelectOp>(
            condbr.getLoc(), condition, std::get<0>(it), std::get<1>(it)));
    }

    rewriter.replaceOpWithNewOp<BranchOp>(condbr, trueDest, mergedOperands);
    return success();
  }
};
///   ...
///   cf.cond_br %cond, ^bb1(...), ^bb2(...)
/// ...
/// ^bb1: // has single predecessor
///   ...
///   cf.cond_br %cond, ^bb3(...), ^bb4(...)
///
/// ->
///
///   ...
///   cf.cond_br %cond, ^bb1(...), ^bb2(...)
/// ...
/// ^bb1: // has single predecessor
///   ...
///   br ^bb3(...)
///
struct SimplifyCondBranchFromCondBranchOnSameCondition
    : public OpRewritePattern<CondBranchOp> {
  using OpRewritePattern<CondBranchOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(CondBranchOp condbr,
                                PatternRewriter &rewriter) const override {
    // Check that we have a single distinct predecessor. This guarantees the
    // condition value is fixed on entry to this block.
    Block *currentBlock = condbr->getBlock();
    Block *predecessor = currentBlock->getSinglePredecessor();
    if (!predecessor)
      return failure();

    // Check that the predecessor terminates with a conditional branch to this
    // block and that it branches on the same condition.
    auto predBranch = dyn_cast<CondBranchOp>(predecessor->getTerminator());
    if (!predBranch || condbr.getCondition() != predBranch.getCondition())
      return failure();

    // Fold this branch to an unconditional branch: the edge taken by the
    // predecessor determines the condition's value in this block.
    if (currentBlock == predBranch.getTrueDest())
      rewriter.replaceOpWithNewOp<BranchOp>(condbr, condbr.getTrueDest(),
                                            condbr.getTrueDestOperands());
    else
      rewriter.replaceOpWithNewOp<BranchOp>(condbr, condbr.getFalseDest(),
                                            condbr.getFalseDestOperands());
    return success();
  }
};
/// cf.cond_br %arg0, ^trueB, ^falseB
///
/// ^trueB:
///   "test.consumer1"(%arg0) : (i1) -> ()
///    ...
///
/// ^falseB:
///   "test.consumer2"(%arg0) : (i1) -> ()
///   ...
///
/// ->
///
/// cf.cond_br %arg0, ^trueB, ^falseB
/// ^trueB:
///   "test.consumer1"(%true) : (i1) -> ()
///    ...
///
/// ^falseB:
///   "test.consumer2"(%false) : (i1) -> ()
///   ...
struct CondBranchTruthPropagation : public OpRewritePattern<CondBranchOp> {
  using OpRewritePattern<CondBranchOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(CondBranchOp condbr,
                                PatternRewriter &rewriter) const override {
    // Tracks whether any use of the condition was replaced.
    bool replaced = false;
    Type ty = rewriter.getI1Type();

    // These variables serve to prevent creating duplicate constants
    // and hold constant true or false values.
    Value constantTrue = nullptr;
    Value constantFalse = nullptr;

    // TODO These checks can be expanded to encompass any use with only
    // either the true or false edge as a predecessor. For now, we fall
    // back to checking the single predecessor is given by the true/false
    // destination, thereby ensuring that only that edge can reach the
    // op.
    if (condbr.getTrueDest()->getSinglePredecessor()) {
      for (OpOperand &use :
           llvm::make_early_inc_range(condbr.getCondition().getUses())) {
        if (use.getOwner()->getBlock() == condbr.getTrueDest()) {
          replaced = true;

          if (!constantTrue)
            constantTrue = rewriter.create<arith::ConstantOp>(
                condbr.getLoc(), ty, rewriter.getBoolAttr(true));

          rewriter.updateRootInPlace(use.getOwner(),
                                     [&] { use.set(constantTrue); });
        }
      }
    }
    if (condbr.getFalseDest()->getSinglePredecessor()) {
      for (OpOperand &use :
           llvm::make_early_inc_range(condbr.getCondition().getUses())) {
        if (use.getOwner()->getBlock() == condbr.getFalseDest()) {
          replaced = true;

          if (!constantFalse)
            constantFalse = rewriter.create<arith::ConstantOp>(
                condbr.getLoc(), ty, rewriter.getBoolAttr(false));

          rewriter.updateRootInPlace(use.getOwner(),
                                     [&] { use.set(constantFalse); });
        }
      }
    }
    return success(replaced);
  }
};
} // namespace
/// Register all cond_br canonicalization patterns defined above.
void CondBranchOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                               MLIRContext *context) {
  results.add<SimplifyConstCondBranchPred, SimplifyPassThroughCondBranch,
              SimplifyCondBranchIdenticalSuccessors,
              SimplifyCondBranchFromCondBranchOnSameCondition,
              CondBranchTruthPropagation>(context);
}
Optional<MutableOperandRange>
CondBranchOp::getMutableSuccessorOperands(unsigned index) {
  assert(index < getNumSuccessors() && "invalid successor index");
  // The true successor forwards the true-destination operands; the false
  // successor forwards the false-destination operands.
  return index == trueIndex ? getTrueDestOperandsMutable()
                            : getFalseDestOperandsMutable();
}

/// Resolve the taken successor when the condition folds to a constant; return
/// nullptr when the successor cannot be determined statically.
Block *CondBranchOp::getSuccessorForOperands(ArrayRef<Attribute> operands) {
  if (IntegerAttr condAttr = operands.front().dyn_cast_or_null<IntegerAttr>())
    return condAttr.getValue().isOneValue() ? getTrueDest() : getFalseDest();
  return nullptr;
}
//===----------------------------------------------------------------------===//
// SwitchOp
//===----------------------------------------------------------------------===//
/// Build a switch from a pre-formed DenseIntElementsAttr of case values.
/// Forwards to the ODS-generated builder, which expects operands before the
/// attribute/successor arguments.
void SwitchOp::build(OpBuilder &builder, OperationState &result, Value value,
                     Block *defaultDestination, ValueRange defaultOperands,
                     DenseIntElementsAttr caseValues,
                     BlockRange caseDestinations,
                     ArrayRef<ValueRange> caseOperands) {
  build(builder, result, value, defaultOperands, caseOperands, caseValues,
        defaultDestination, caseDestinations);
}

/// Build a switch from raw APInt case values; wraps them into a
/// DenseIntElementsAttr (left null when there are no cases) and delegates to
/// the builder above.
void SwitchOp::build(OpBuilder &builder, OperationState &result, Value value,
                     Block *defaultDestination, ValueRange defaultOperands,
                     ArrayRef<APInt> caseValues, BlockRange caseDestinations,
                     ArrayRef<ValueRange> caseOperands) {
  DenseIntElementsAttr caseValuesAttr;
  if (!caseValues.empty()) {
    // The attribute is stored as a 1-D vector with the flag's element type.
    ShapedType caseValueType = VectorType::get(
        static_cast<int64_t>(caseValues.size()), value.getType());
    caseValuesAttr = DenseIntElementsAttr::get(caseValueType, caseValues);
  }
  build(builder, result, value, defaultDestination, defaultOperands,
        caseValuesAttr, caseDestinations, caseOperands);
}
/// <cases> ::= `default` `:` bb-id (`(` ssa-use-and-type-list `)`)?
///             ( `,` integer `:` bb-id (`(` ssa-use-and-type-list `)`)? )*
///
/// Parses the custom case list of a switch: a mandatory default destination
/// followed by zero or more `value: ^block(operands)` entries. The parsed case
/// values are packed into `caseValues` (left null when there are no cases).
static ParseResult parseSwitchOpCases(
    OpAsmParser &parser, Type &flagType, Block *&defaultDestination,
    SmallVectorImpl<OpAsmParser::OperandType> &defaultOperands,
    SmallVectorImpl<Type> &defaultOperandTypes,
    DenseIntElementsAttr &caseValues,
    SmallVectorImpl<Block *> &caseDestinations,
    SmallVectorImpl<SmallVector<OpAsmParser::OperandType>> &caseOperands,
    SmallVectorImpl<SmallVector<Type>> &caseOperandTypes) {
  // `default` `:` bb-id is always required.
  if (parser.parseKeyword("default") || parser.parseColon() ||
      parser.parseSuccessor(defaultDestination))
    return failure();
  // Optional parenthesized operand list for the default destination.
  if (succeeded(parser.parseOptionalLParen())) {
    if (parser.parseRegionArgumentList(defaultOperands) ||
        parser.parseColonTypeList(defaultOperandTypes) || parser.parseRParen())
      return failure();
  }

  SmallVector<APInt> values;
  // Case values are parsed at the bit width of the flag type.
  unsigned bitWidth = flagType.getIntOrFloatBitWidth();
  // Each subsequent case is introduced by a comma.
  while (succeeded(parser.parseOptionalComma())) {
    int64_t value = 0;
    if (failed(parser.parseInteger(value)))
      return failure();
    values.push_back(APInt(bitWidth, value));

    Block *destination;
    SmallVector<OpAsmParser::OperandType> operands;
    SmallVector<Type> operandTypes;
    if (failed(parser.parseColon()) ||
        failed(parser.parseSuccessor(destination)))
      return failure();
    // Optional parenthesized operand list for this case destination.
    if (succeeded(parser.parseOptionalLParen())) {
      if (failed(parser.parseRegionArgumentList(operands)) ||
          failed(parser.parseColonTypeList(operandTypes)) ||
          failed(parser.parseRParen()))
        return failure();
    }
    caseDestinations.push_back(destination);
    caseOperands.emplace_back(operands);
    caseOperandTypes.emplace_back(operandTypes);
  }

  // Pack the collected case values into the dense attribute; when there are
  // no cases the attribute stays null.
  if (!values.empty()) {
    ShapedType caseValueType =
        VectorType::get(static_cast<int64_t>(values.size()), flagType);
    caseValues = DenseIntElementsAttr::get(caseValueType, values);
  }
  return success();
}
/// Prints the case list of a switch in the format accepted by
/// parseSwitchOpCases: the default destination first, then one
/// `value: ^block(operands)` entry per line.
static void printSwitchOpCases(
    OpAsmPrinter &p, SwitchOp op, Type flagType, Block *defaultDestination,
    OperandRange defaultOperands, TypeRange defaultOperandTypes,
    DenseIntElementsAttr caseValues, SuccessorRange caseDestinations,
    OperandRangeRange caseOperands, const TypeRangeRange &caseOperandTypes) {
  p << " default: ";
  p.printSuccessorAndUseList(defaultDestination, defaultOperands);

  // A null attribute means there are no cases to print.
  if (!caseValues)
    return;

  for (const auto &it : llvm::enumerate(caseValues.getValues<APInt>())) {
    p << ',';
    p.printNewline();
    p << " ";
    p << it.value().getLimitedValue();
    p << ": ";
    p.printSuccessorAndUseList(caseDestinations[it.index()],
                               caseOperands[it.index()]);
  }
  p.printNewline();
}
/// Verifies that the case values (when present) agree with the flag type and
/// that their count matches the number of case destinations.
LogicalResult SwitchOp::verify() {
  auto caseValues = getCaseValues();
  auto caseDestinations = getCaseDestinations();

  // A switch with no cases at all is trivially valid.
  if (!caseValues && caseDestinations.empty())
    return success();

  // Guard before dereferencing: case destinations without accompanying case
  // values is invalid IR and must be diagnosed, not crash the verifier.
  if (!caseValues)
    return emitOpError() << "expected 'case_values' to be specified when "
                            "case destinations are present";

  Type flagType = getFlag().getType();
  Type caseValueType = caseValues->getType().getElementType();
  if (caseValueType != flagType)
    return emitOpError() << "'flag' type (" << flagType
                         << ") should match case value type (" << caseValueType
                         << ")";
  // `caseValues` is known non-null here, so the count check is unconditional.
  if (caseValues->size() != static_cast<int64_t>(caseDestinations.size()))
    return emitOpError() << "number of case values (" << caseValues->size()
                         << ") should match number of "
                            "case destinations ("
                         << caseDestinations.size() << ")";
  return success();
}
Optional<MutableOperandRange>
SwitchOp::getMutableSuccessorOperands(unsigned index) {
  assert(index < getNumSuccessors() && "invalid successor index");
  // Successor 0 is the default destination; successors 1..N correspond to
  // case destinations 0..N-1.
  if (index == 0)
    return getDefaultOperandsMutable();
  return getCaseOperandsMutable(index - 1);
}
/// Resolve the taken successor when the flag folds to a constant integer;
/// return nullptr when it cannot be determined statically.
Block *SwitchOp::getSuccessorForOperands(ArrayRef<Attribute> operands) {
  Optional<DenseIntElementsAttr> caseValues = getCaseValues();

  // Without case values every flag value takes the default edge.
  if (!caseValues)
    return getDefaultDestination();

  SuccessorRange caseDests = getCaseDestinations();
  if (auto value = operands.front().dyn_cast_or_null<IntegerAttr>()) {
    // First matching case wins; otherwise fall through to the default.
    for (const auto &it : llvm::enumerate(caseValues->getValues<APInt>()))
      if (it.value() == value.getValue())
        return caseDests[it.index()];
    return getDefaultDestination();
  }
  return nullptr;
}
/// switch %flag : i32, [
///   default:  ^bb1
/// ]
///  -> br ^bb1
static LogicalResult simplifySwitchWithOnlyDefault(SwitchOp op,
                                                   PatternRewriter &rewriter) {
  // With no case edges the switch unconditionally takes the default edge and
  // degenerates into a plain branch.
  if (op.getCaseDestinations().empty()) {
    rewriter.replaceOpWithNewOp<BranchOp>(op, op.getDefaultDestination(),
                                          op.getDefaultOperands());
    return success();
  }
  return failure();
}
/// switch %flag : i32, [
///   default: ^bb1,
///   42: ^bb1,
///   43: ^bb2
/// ]
/// ->
/// switch %flag : i32, [
///   default: ^bb1,
///   43: ^bb2
/// ]
static LogicalResult
dropSwitchCasesThatMatchDefault(SwitchOp op, PatternRewriter &rewriter) {
  SmallVector<Block *> newCaseDestinations;
  SmallVector<ValueRange> newCaseOperands;
  SmallVector<APInt> newCaseValues;
  bool requiresChange = false;
  auto caseValues = op.getCaseValues();
  auto caseDests = op.getCaseDestinations();

  // Keep only the cases whose destination or operands differ from the
  // default edge; identical cases are redundant.
  for (const auto &it : llvm::enumerate(caseValues->getValues<APInt>())) {
    if (caseDests[it.index()] == op.getDefaultDestination() &&
        op.getCaseOperands(it.index()) == op.getDefaultOperands()) {
      requiresChange = true;
      continue;
    }
    newCaseDestinations.push_back(caseDests[it.index()]);
    newCaseOperands.push_back(op.getCaseOperands(it.index()));
    newCaseValues.push_back(it.value());
  }

  if (!requiresChange)
    return failure();

  rewriter.replaceOpWithNewOp<SwitchOp>(
      op, op.getFlag(), op.getDefaultDestination(), op.getDefaultOperands(),
      newCaseValues, newCaseDestinations, newCaseOperands);
  return success();
}
/// Helper for folding a switch with a constant value.
/// switch %c_42 : i32, [
///   default: ^bb1 ,
///   42: ^bb2,
///   43: ^bb3
/// ]
/// -> br ^bb2
///
/// Callers must ensure the op has a case-values attribute; the first matching
/// case is taken, and the default edge is used when none matches.
static void foldSwitch(SwitchOp op, PatternRewriter &rewriter,
                       const APInt &caseValue) {
  auto caseValues = op.getCaseValues();
  for (const auto &it : llvm::enumerate(caseValues->getValues<APInt>())) {
    if (it.value() == caseValue) {
      rewriter.replaceOpWithNewOp<BranchOp>(
          op, op.getCaseDestinations()[it.index()],
          op.getCaseOperands(it.index()));
      return;
    }
  }
  // No case matched: the default edge is taken.
  rewriter.replaceOpWithNewOp<BranchOp>(op, op.getDefaultDestination(),
                                        op.getDefaultOperands());
}
/// switch %c_42 : i32, [
///   default: ^bb1,
///   42: ^bb2,
///   43: ^bb3
/// ]
/// -> br ^bb2
static LogicalResult simplifyConstSwitchValue(SwitchOp op,
                                              PatternRewriter &rewriter) {
  // Only applies when the flag folds to a compile-time integer constant.
  APInt flagValue;
  if (matchPattern(op.getFlag(), m_ConstantInt(&flagValue))) {
    foldSwitch(op, rewriter, flagValue);
    return success();
  }
  return failure();
}
/// switch %c_42 : i32, [
/// default: ^bb1,
/// 42: ^bb2,
/// ]
/// ^bb2:
/// br ^bb3
/// ->
/// switch %c_42 : i32, [
/// default: ^bb1,
/// 42: ^bb3,
/// ]
static LogicalResult simplifyPassThroughSwitch(SwitchOp op,
PatternRewriter &rewriter) {
SmallVector<Block *> newCaseDests;
SmallVector<ValueRange> newCaseOperands;
SmallVector<SmallVector<Value>> argStorage;
auto caseValues = op.getCaseValues();
auto caseDests = op.getCaseDestinations();
bool requiresChange = false;
for (int64_t i = 0, size = caseValues->size(); i < size; ++i) {
Block *caseDest = caseDests[i];
ValueRange caseOperands = op.getCaseOperands(i);
argStorage.emplace_back();
if (succeeded(collapseBranch(caseDest, caseOperands, argStorage.back())))
requiresChange = true;
newCaseDests.push_back(caseDest);
newCaseOperands.push_back(caseOperands);
}
Block *defaultDest = op.getDefaultDestination();
ValueRange defaultOperands = op.getDefaultOperands();
argStorage.emplace_back();
if (succeeded(
collapseBranch(defaultDest, defaultOperands, argStorage.back())))
requiresChange = true;
if (!requiresChange)
return failure();
rewriter.replaceOpWithNewOp<SwitchOp>(op, op.getFlag(), defaultDest,
defaultOperands, caseValues.getValue(),
newCaseDests, newCaseOperands);
return success();
}
/// Fold a switch that is dominated by a switch on the same flag: if the sole
/// predecessor dispatched here through a (non-default) case, the flag's value
/// is known and this switch collapses to a branch.
///
/// switch %flag : i32, [
///   default: ^bb1,
///   42: ^bb2,
/// ]
/// ^bb2:
///   switch %flag : i32, [
///     default: ^bb3,
///     42: ^bb4
///   ]
/// ->
/// ... ^bb2: br ^bb4
///
/// and, when no case of this switch matches the implied value:
///
/// switch %flag : i32, [
///   default: ^bb1,
///   42: ^bb2,
/// ]
/// ^bb2:
///   switch %flag : i32, [
///     default: ^bb3,
///     43: ^bb4
///   ]
/// ->
/// ... ^bb2: br ^bb3
static LogicalResult
simplifySwitchFromSwitchOnSameCondition(SwitchOp op,
                                        PatternRewriter &rewriter) {
  // Only applies when this block is reached from exactly one predecessor.
  Block *block = op->getBlock();
  Block *pred = block->getSinglePredecessor();
  if (!pred)
    return failure();

  // The predecessor must end in a switch over the same flag, and this block
  // must not be reached via its default edge (the default conveys no value
  // information).
  auto parentSwitch = dyn_cast<SwitchOp>(pred->getTerminator());
  if (!parentSwitch || parentSwitch.getFlag() != op.getFlag() ||
      parentSwitch.getDefaultDestination() == block)
    return failure();

  // Locate the parent case that targets this block; its case value is the
  // value the flag must hold here.
  SuccessorRange parentDests = parentSwitch.getCaseDestinations();
  auto match = llvm::find(parentDests, block);
  if (match == parentDests.end()) {
    // This block is not a parent case destination; fall back to the default.
    rewriter.replaceOpWithNewOp<BranchOp>(op, op.getDefaultDestination(),
                                          op.getDefaultOperands());
    return success();
  }
  Optional<DenseIntElementsAttr> parentValues = parentSwitch.getCaseValues();
  foldSwitch(op, rewriter,
             parentValues->getValues<APInt>()[match - parentDests.begin()]);
  return success();
}
/// Prune impossible cases from a switch reached through the default edge of a
/// switch on the same flag: any value the parent routes to one of its case
/// blocks can never reach this block.
///
/// switch %flag : i32, [
///   default: ^bb1,
///   42: ^bb2
/// ]
/// ^bb1:
///   switch %flag : i32, [
///     default: ^bb3,
///     42: ^bb4,
///     43: ^bb5
///   ]
/// ->
/// ... ^bb1:
///   switch %flag : i32, [
///     default: ^bb3,
///     43: ^bb5
///   ]
static LogicalResult
simplifySwitchFromDefaultSwitchOnSameCondition(SwitchOp op,
                                               PatternRewriter &rewriter) {
  // Only applies when this block is reached from exactly one predecessor.
  Block *block = op->getBlock();
  Block *pred = block->getSinglePredecessor();
  if (!pred)
    return failure();

  // The predecessor must end in a switch over the same flag, and this block
  // must be its default destination.
  auto parentSwitch = dyn_cast<SwitchOp>(pred->getTerminator());
  if (!parentSwitch || parentSwitch.getFlag() != op.getFlag() ||
      parentSwitch.getDefaultDestination() != block)
    return failure();

  // Collect the flag values that cannot occur here: every value the parent
  // switch dispatches to a block other than this one.
  DenseSet<APInt> impossibleValues;
  auto parentDests = parentSwitch.getCaseDestinations();
  auto parentValues = parentSwitch.getCaseValues();
  for (int64_t i = 0, e = parentValues->size(); i != e; ++i)
    if (parentDests[i] != block)
      impossibleValues.insert(parentValues->getValues<APInt>()[i]);

  // Rebuild the case list, dropping entries for impossible values.
  SmallVector<Block *> keptDests;
  SmallVector<ValueRange> keptOperands;
  SmallVector<APInt> keptValues;
  bool changed = false;
  auto caseValues = op.getCaseValues();
  auto caseDests = op.getCaseDestinations();
  int64_t index = 0;
  for (const APInt &value : caseValues->getValues<APInt>()) {
    if (impossibleValues.contains(value)) {
      changed = true;
    } else {
      keptDests.push_back(caseDests[index]);
      keptOperands.push_back(op.getCaseOperands(index));
      keptValues.push_back(value);
    }
    ++index;
  }
  if (!changed)
    return failure();

  rewriter.replaceOpWithNewOp<SwitchOp>(
      op, op.getFlag(), op.getDefaultDestination(), op.getDefaultOperands(),
      keptValues, keptDests, keptOperands);
  return success();
}
void SwitchOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                           MLIRContext *context) {
  // Register each switch canonicalization as a function-based pattern.
  results.add(&simplifySwitchWithOnlyDefault);
  results.add(&dropSwitchCasesThatMatchDefault);
  results.add(&simplifyConstSwitchValue);
  results.add(&simplifyPassThroughSwitch);
  results.add(&simplifySwitchFromSwitchOnSameCondition);
  results.add(&simplifySwitchFromDefaultSwitchOnSameCondition);
}
//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//
#define GET_OP_CLASSES
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.cpp.inc"

View File

@ -12,10 +12,10 @@
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/GPU/GPUDialect.h"
#include "mlir/Dialect/GPU/Passes.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/PatternMatch.h"
@ -44,14 +44,14 @@ struct GpuAllReduceRewriter {
/// workgroup memory.
///
/// %subgroup_reduce = `createSubgroupReduce(%operand)`
/// cond_br %is_first_lane, ^then1, ^continue1
/// cf.cond_br %is_first_lane, ^then1, ^continue1
/// ^then1:
/// store %subgroup_reduce, %workgroup_buffer[%subgroup_id]
/// br ^continue1
/// cf.br ^continue1
/// ^continue1:
/// gpu.barrier
/// %is_valid_subgroup = arith.cmpi "slt" %invocation_idx, %num_subgroups
/// cond_br %is_valid_subgroup, ^then2, ^continue2
/// cf.cond_br %is_valid_subgroup, ^then2, ^continue2
/// ^then2:
/// %partial_reduce = load %workgroup_buffer[%invocation_idx]
/// %all_reduce = `createSubgroupReduce(%partial_reduce)`
@ -194,7 +194,7 @@ private:
// Add branch before inserted body, into body.
block = block->getNextNode();
create<BranchOp>(block, ValueRange());
create<cf::BranchOp>(block, ValueRange());
// Replace all gpu.yield ops with branch out of body.
for (; block != split; block = block->getNextNode()) {
@ -202,7 +202,7 @@ private:
if (!isa<gpu::YieldOp>(terminator))
continue;
rewriter.setInsertionPointToEnd(block);
rewriter.replaceOpWithNewOp<BranchOp>(
rewriter.replaceOpWithNewOp<cf::BranchOp>(
terminator, split, ValueRange(terminator->getOperand(0)));
}
@ -285,17 +285,17 @@ private:
Block *continueBlock = rewriter.splitBlock(elseBlock, elseBlock->begin());
rewriter.setInsertionPointToEnd(currentBlock);
create<CondBranchOp>(condition, thenBlock,
/*trueOperands=*/ArrayRef<Value>(), elseBlock,
/*falseOperands=*/ArrayRef<Value>());
create<cf::CondBranchOp>(condition, thenBlock,
/*trueOperands=*/ArrayRef<Value>(), elseBlock,
/*falseOperands=*/ArrayRef<Value>());
rewriter.setInsertionPointToStart(thenBlock);
auto thenOperands = thenOpsFactory();
create<BranchOp>(continueBlock, thenOperands);
create<cf::BranchOp>(continueBlock, thenOperands);
rewriter.setInsertionPointToStart(elseBlock);
auto elseOperands = elseOpsFactory();
create<BranchOp>(continueBlock, elseOperands);
create<cf::BranchOp>(continueBlock, elseOperands);
assert(thenOperands.size() == elseOperands.size());
rewriter.setInsertionPointToStart(continueBlock);

View File

@ -12,6 +12,7 @@
#include "PassDetail.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/DLTI/DLTI.h"
#include "mlir/Dialect/GPU/GPUDialect.h"
#include "mlir/Dialect/GPU/Passes.h"
@ -186,7 +187,7 @@ static gpu::GPUFuncOp outlineKernelFuncImpl(gpu::LaunchOp launchOp,
Block &launchOpEntry = launchOpBody.front();
Block *clonedLaunchOpEntry = map.lookup(&launchOpEntry);
builder.setInsertionPointToEnd(&entryBlock);
builder.create<BranchOp>(loc, clonedLaunchOpEntry);
builder.create<cf::BranchOp>(loc, clonedLaunchOpEntry);
outlinedFunc.walk([](gpu::TerminatorOp op) {
OpBuilder replacer(op);

View File

@ -7,6 +7,7 @@
//===----------------------------------------------------------------------===//
#include "PassDetail.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/StandardOps/Transforms/FuncConversions.h"
@ -254,13 +255,13 @@ struct LinalgDetensorize : public LinalgDetensorizeBase<LinalgDetensorize> {
DenseSet<BlockArgument> &blockArgsToDetensor) override {
SmallVector<Value> workList;
func->walk([&](CondBranchOp condBr) {
func->walk([&](cf::CondBranchOp condBr) {
for (auto operand : condBr.getOperands()) {
workList.push_back(operand);
}
});
func->walk([&](BranchOp br) {
func->walk([&](cf::BranchOp br) {
for (auto operand : br.getOperands()) {
workList.push_back(operand);
}

View File

@ -9,6 +9,7 @@
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/Matchers.h"
@ -165,13 +166,13 @@ struct SingleBlockExecuteInliner : public OpRewritePattern<ExecuteRegionOp> {
// "test.foo"() : () -> ()
// %v = scf.execute_region -> i64 {
// %c = "test.cmp"() : () -> i1
// cond_br %c, ^bb2, ^bb3
// cf.cond_br %c, ^bb2, ^bb3
// ^bb2:
// %x = "test.val1"() : () -> i64
// br ^bb4(%x : i64)
// cf.br ^bb4(%x : i64)
// ^bb3:
// %y = "test.val2"() : () -> i64
// br ^bb4(%y : i64)
// cf.br ^bb4(%y : i64)
// ^bb4(%z : i64):
// scf.yield %z : i64
// }
@ -184,13 +185,13 @@ struct SingleBlockExecuteInliner : public OpRewritePattern<ExecuteRegionOp> {
// func @func_execute_region_elim() {
// "test.foo"() : () -> ()
// %c = "test.cmp"() : () -> i1
// cond_br %c, ^bb1, ^bb2
// cf.cond_br %c, ^bb1, ^bb2
// ^bb1: // pred: ^bb0
// %x = "test.val1"() : () -> i64
// br ^bb3(%x : i64)
// cf.br ^bb3(%x : i64)
// ^bb2: // pred: ^bb0
// %y = "test.val2"() : () -> i64
// br ^bb3(%y : i64)
// cf.br ^bb3(%y : i64)
// ^bb3(%z: i64): // 2 preds: ^bb1, ^bb2
// "test.bar"(%z) : (i64) -> ()
// return
@ -208,13 +209,13 @@ struct MultiBlockExecuteInliner : public OpRewritePattern<ExecuteRegionOp> {
Block *postBlock = rewriter.splitBlock(prevBlock, op->getIterator());
rewriter.setInsertionPointToEnd(prevBlock);
rewriter.create<BranchOp>(op.getLoc(), &op.getRegion().front());
rewriter.create<cf::BranchOp>(op.getLoc(), &op.getRegion().front());
for (Block &blk : op.getRegion()) {
if (YieldOp yieldOp = dyn_cast<YieldOp>(blk.getTerminator())) {
rewriter.setInsertionPoint(yieldOp);
rewriter.create<BranchOp>(yieldOp.getLoc(), postBlock,
yieldOp.getResults());
rewriter.create<cf::BranchOp>(yieldOp.getLoc(), postBlock,
yieldOp.getResults());
rewriter.eraseOp(yieldOp);
}
}

View File

@ -13,7 +13,7 @@ add_mlir_dialect_library(MLIRSparseTensorPipelines
MLIRMemRefToLLVM
MLIRPass
MLIRReconcileUnrealizedCasts
MLIRSCFToStandard
MLIRSCFToControlFlow
MLIRSparseTensor
MLIRSparseTensorTransforms
MLIRStandardOpsTransforms

View File

@ -33,7 +33,7 @@ void mlir::sparse_tensor::buildSparseCompiler(
pm.addNestedPass<FuncOp>(createLinalgBufferizePass());
pm.addNestedPass<FuncOp>(createConvertLinalgToLoopsPass());
pm.addNestedPass<FuncOp>(createConvertVectorToSCFPass());
pm.addPass(createLowerToCFGPass()); // --convert-scf-to-std
pm.addNestedPass<FuncOp>(createConvertSCFToCFPass());
pm.addPass(createFuncBufferizePass());
pm.addPass(arith::createConstantBufferizePass());
pm.addNestedPass<FuncOp>(createTensorBufferizePass());

View File

@ -11,6 +11,7 @@ add_mlir_dialect_library(MLIRStandard
MLIRArithmetic
MLIRCallInterfaces
MLIRCastInterfaces
MLIRControlFlow
MLIRControlFlowInterfaces
MLIRInferTypeOpInterface
MLIRIR

View File

@ -8,9 +8,8 @@
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Arithmetic/Utils/Utils.h"
#include "mlir/Dialect/CommonFolders.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BlockAndValueMapping.h"
@ -77,7 +76,7 @@ struct StdInlinerInterface : public DialectInlinerInterface {
// Replace the return with a branch to the dest.
OpBuilder builder(op);
builder.create<BranchOp>(op->getLoc(), newDest, returnOp.getOperands());
builder.create<cf::BranchOp>(op->getLoc(), newDest, returnOp.getOperands());
op->erase();
}
@ -121,130 +120,6 @@ Operation *StandardOpsDialect::materializeConstant(OpBuilder &builder,
return nullptr;
}
//===----------------------------------------------------------------------===//
// AssertOp
//===----------------------------------------------------------------------===//
LogicalResult AssertOp::canonicalize(AssertOp op, PatternRewriter &rewriter) {
// Erase assertion if argument is constant true.
if (matchPattern(op.getArg(), m_One())) {
rewriter.eraseOp(op);
return success();
}
return failure();
}
//===----------------------------------------------------------------------===//
// BranchOp
//===----------------------------------------------------------------------===//
/// Given a successor, try to collapse it to a new destination if it only
/// contains a passthrough unconditional branch. If the successor is
/// collapsable, `successor` and `successorOperands` are updated to reference
/// the new destination and values. `argStorage` is used as storage if operands
/// to the collapsed successor need to be remapped. It must outlive uses of
/// successorOperands.
static LogicalResult collapseBranch(Block *&successor,
ValueRange &successorOperands,
SmallVectorImpl<Value> &argStorage) {
// Check that the successor only contains a unconditional branch.
if (std::next(successor->begin()) != successor->end())
return failure();
// Check that the terminator is an unconditional branch.
BranchOp successorBranch = dyn_cast<BranchOp>(successor->getTerminator());
if (!successorBranch)
return failure();
// Check that the arguments are only used within the terminator.
for (BlockArgument arg : successor->getArguments()) {
for (Operation *user : arg.getUsers())
if (user != successorBranch)
return failure();
}
// Don't try to collapse branches to infinite loops.
Block *successorDest = successorBranch.getDest();
if (successorDest == successor)
return failure();
// Update the operands to the successor. If the branch parent has no
// arguments, we can use the branch operands directly.
OperandRange operands = successorBranch.getOperands();
if (successor->args_empty()) {
successor = successorDest;
successorOperands = operands;
return success();
}
// Otherwise, we need to remap any argument operands.
for (Value operand : operands) {
BlockArgument argOperand = operand.dyn_cast<BlockArgument>();
if (argOperand && argOperand.getOwner() == successor)
argStorage.push_back(successorOperands[argOperand.getArgNumber()]);
else
argStorage.push_back(operand);
}
successor = successorDest;
successorOperands = argStorage;
return success();
}
/// Simplify a branch to a block that has a single predecessor. This effectively
/// merges the two blocks.
static LogicalResult
simplifyBrToBlockWithSinglePred(BranchOp op, PatternRewriter &rewriter) {
// Check that the successor block has a single predecessor.
Block *succ = op.getDest();
Block *opParent = op->getBlock();
if (succ == opParent || !llvm::hasSingleElement(succ->getPredecessors()))
return failure();
// Merge the successor into the current block and erase the branch.
rewriter.mergeBlocks(succ, opParent, op.getOperands());
rewriter.eraseOp(op);
return success();
}
/// br ^bb1
/// ^bb1
/// br ^bbN(...)
///
/// -> br ^bbN(...)
///
static LogicalResult simplifyPassThroughBr(BranchOp op,
PatternRewriter &rewriter) {
Block *dest = op.getDest();
ValueRange destOperands = op.getOperands();
SmallVector<Value, 4> destOperandStorage;
// Try to collapse the successor if it points somewhere other than this
// block.
if (dest == op->getBlock() ||
failed(collapseBranch(dest, destOperands, destOperandStorage)))
return failure();
// Create a new branch with the collapsed successor.
rewriter.replaceOpWithNewOp<BranchOp>(op, dest, destOperands);
return success();
}
LogicalResult BranchOp::canonicalize(BranchOp op, PatternRewriter &rewriter) {
return success(succeeded(simplifyBrToBlockWithSinglePred(op, rewriter)) ||
succeeded(simplifyPassThroughBr(op, rewriter)));
}
void BranchOp::setDest(Block *block) { return setSuccessor(block); }
void BranchOp::eraseOperand(unsigned index) { (*this)->eraseOperand(index); }
Optional<MutableOperandRange>
BranchOp::getMutableSuccessorOperands(unsigned index) {
assert(index == 0 && "invalid successor index");
return getDestOperandsMutable();
}
Block *BranchOp::getSuccessorForOperands(ArrayRef<Attribute>) {
return getDest();
}
//===----------------------------------------------------------------------===//
// CallOp
//===----------------------------------------------------------------------===//
@ -307,260 +182,6 @@ LogicalResult CallIndirectOp::canonicalize(CallIndirectOp indirectCall,
return success();
}
//===----------------------------------------------------------------------===//
// CondBranchOp
//===----------------------------------------------------------------------===//
namespace {
/// cond_br true, ^bb1, ^bb2
/// -> br ^bb1
/// cond_br false, ^bb1, ^bb2
/// -> br ^bb2
///
struct SimplifyConstCondBranchPred : public OpRewritePattern<CondBranchOp> {
using OpRewritePattern<CondBranchOp>::OpRewritePattern;
LogicalResult matchAndRewrite(CondBranchOp condbr,
PatternRewriter &rewriter) const override {
if (matchPattern(condbr.getCondition(), m_NonZero())) {
// True branch taken.
rewriter.replaceOpWithNewOp<BranchOp>(condbr, condbr.getTrueDest(),
condbr.getTrueOperands());
return success();
}
if (matchPattern(condbr.getCondition(), m_Zero())) {
// False branch taken.
rewriter.replaceOpWithNewOp<BranchOp>(condbr, condbr.getFalseDest(),
condbr.getFalseOperands());
return success();
}
return failure();
}
};
/// cond_br %cond, ^bb1, ^bb2
/// ^bb1
/// br ^bbN(...)
/// ^bb2
/// br ^bbK(...)
///
/// -> cond_br %cond, ^bbN(...), ^bbK(...)
///
struct SimplifyPassThroughCondBranch : public OpRewritePattern<CondBranchOp> {
using OpRewritePattern<CondBranchOp>::OpRewritePattern;
LogicalResult matchAndRewrite(CondBranchOp condbr,
PatternRewriter &rewriter) const override {
Block *trueDest = condbr.getTrueDest(), *falseDest = condbr.getFalseDest();
ValueRange trueDestOperands = condbr.getTrueOperands();
ValueRange falseDestOperands = condbr.getFalseOperands();
SmallVector<Value, 4> trueDestOperandStorage, falseDestOperandStorage;
// Try to collapse one of the current successors.
LogicalResult collapsedTrue =
collapseBranch(trueDest, trueDestOperands, trueDestOperandStorage);
LogicalResult collapsedFalse =
collapseBranch(falseDest, falseDestOperands, falseDestOperandStorage);
if (failed(collapsedTrue) && failed(collapsedFalse))
return failure();
// Create a new branch with the collapsed successors.
rewriter.replaceOpWithNewOp<CondBranchOp>(condbr, condbr.getCondition(),
trueDest, trueDestOperands,
falseDest, falseDestOperands);
return success();
}
};
/// cond_br %cond, ^bb1(A, ..., N), ^bb1(A, ..., N)
/// -> br ^bb1(A, ..., N)
///
/// cond_br %cond, ^bb1(A), ^bb1(B)
/// -> %select = arith.select %cond, A, B
/// br ^bb1(%select)
///
struct SimplifyCondBranchIdenticalSuccessors
: public OpRewritePattern<CondBranchOp> {
using OpRewritePattern<CondBranchOp>::OpRewritePattern;
LogicalResult matchAndRewrite(CondBranchOp condbr,
PatternRewriter &rewriter) const override {
// Check that the true and false destinations are the same and have the same
// operands.
Block *trueDest = condbr.getTrueDest();
if (trueDest != condbr.getFalseDest())
return failure();
// If all of the operands match, no selects need to be generated.
OperandRange trueOperands = condbr.getTrueOperands();
OperandRange falseOperands = condbr.getFalseOperands();
if (trueOperands == falseOperands) {
rewriter.replaceOpWithNewOp<BranchOp>(condbr, trueDest, trueOperands);
return success();
}
// Otherwise, if the current block is the only predecessor insert selects
// for any mismatched branch operands.
if (trueDest->getUniquePredecessor() != condbr->getBlock())
return failure();
// Generate a select for any operands that differ between the two.
SmallVector<Value, 8> mergedOperands;
mergedOperands.reserve(trueOperands.size());
Value condition = condbr.getCondition();
for (auto it : llvm::zip(trueOperands, falseOperands)) {
if (std::get<0>(it) == std::get<1>(it))
mergedOperands.push_back(std::get<0>(it));
else
mergedOperands.push_back(rewriter.create<arith::SelectOp>(
condbr.getLoc(), condition, std::get<0>(it), std::get<1>(it)));
}
rewriter.replaceOpWithNewOp<BranchOp>(condbr, trueDest, mergedOperands);
return success();
}
};
/// ...
/// cond_br %cond, ^bb1(...), ^bb2(...)
/// ...
/// ^bb1: // has single predecessor
/// ...
/// cond_br %cond, ^bb3(...), ^bb4(...)
///
/// ->
///
/// ...
/// cond_br %cond, ^bb1(...), ^bb2(...)
/// ...
/// ^bb1: // has single predecessor
/// ...
/// br ^bb3(...)
///
struct SimplifyCondBranchFromCondBranchOnSameCondition
: public OpRewritePattern<CondBranchOp> {
using OpRewritePattern<CondBranchOp>::OpRewritePattern;
LogicalResult matchAndRewrite(CondBranchOp condbr,
PatternRewriter &rewriter) const override {
// Check that we have a single distinct predecessor.
Block *currentBlock = condbr->getBlock();
Block *predecessor = currentBlock->getSinglePredecessor();
if (!predecessor)
return failure();
// Check that the predecessor terminates with a conditional branch to this
// block and that it branches on the same condition.
auto predBranch = dyn_cast<CondBranchOp>(predecessor->getTerminator());
if (!predBranch || condbr.getCondition() != predBranch.getCondition())
return failure();
// Fold this branch to an unconditional branch.
if (currentBlock == predBranch.getTrueDest())
rewriter.replaceOpWithNewOp<BranchOp>(condbr, condbr.getTrueDest(),
condbr.getTrueDestOperands());
else
rewriter.replaceOpWithNewOp<BranchOp>(condbr, condbr.getFalseDest(),
condbr.getFalseDestOperands());
return success();
}
};
/// cond_br %arg0, ^trueB, ^falseB
///
/// ^trueB:
/// "test.consumer1"(%arg0) : (i1) -> ()
/// ...
///
/// ^falseB:
/// "test.consumer2"(%arg0) : (i1) -> ()
/// ...
///
/// ->
///
/// cond_br %arg0, ^trueB, ^falseB
/// ^trueB:
/// "test.consumer1"(%true) : (i1) -> ()
/// ...
///
/// ^falseB:
/// "test.consumer2"(%false) : (i1) -> ()
/// ...
struct CondBranchTruthPropagation : public OpRewritePattern<CondBranchOp> {
using OpRewritePattern<CondBranchOp>::OpRewritePattern;
LogicalResult matchAndRewrite(CondBranchOp condbr,
PatternRewriter &rewriter) const override {
// Check that we have a single distinct predecessor.
bool replaced = false;
Type ty = rewriter.getI1Type();
// These variables serve to prevent creating duplicate constants
// and hold constant true or false values.
Value constantTrue = nullptr;
Value constantFalse = nullptr;
// TODO These checks can be expanded to encompas any use with only
// either the true of false edge as a predecessor. For now, we fall
// back to checking the single predecessor is given by the true/fasle
// destination, thereby ensuring that only that edge can reach the
// op.
if (condbr.getTrueDest()->getSinglePredecessor()) {
for (OpOperand &use :
llvm::make_early_inc_range(condbr.getCondition().getUses())) {
if (use.getOwner()->getBlock() == condbr.getTrueDest()) {
replaced = true;
if (!constantTrue)
constantTrue = rewriter.create<arith::ConstantOp>(
condbr.getLoc(), ty, rewriter.getBoolAttr(true));
rewriter.updateRootInPlace(use.getOwner(),
[&] { use.set(constantTrue); });
}
}
}
if (condbr.getFalseDest()->getSinglePredecessor()) {
for (OpOperand &use :
llvm::make_early_inc_range(condbr.getCondition().getUses())) {
if (use.getOwner()->getBlock() == condbr.getFalseDest()) {
replaced = true;
if (!constantFalse)
constantFalse = rewriter.create<arith::ConstantOp>(
condbr.getLoc(), ty, rewriter.getBoolAttr(false));
rewriter.updateRootInPlace(use.getOwner(),
[&] { use.set(constantFalse); });
}
}
}
return success(replaced);
}
};
} // namespace
void CondBranchOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.add<SimplifyConstCondBranchPred, SimplifyPassThroughCondBranch,
SimplifyCondBranchIdenticalSuccessors,
SimplifyCondBranchFromCondBranchOnSameCondition,
CondBranchTruthPropagation>(context);
}
Optional<MutableOperandRange>
CondBranchOp::getMutableSuccessorOperands(unsigned index) {
assert(index < getNumSuccessors() && "invalid successor index");
return index == trueIndex ? getTrueDestOperandsMutable()
: getFalseDestOperandsMutable();
}
Block *CondBranchOp::getSuccessorForOperands(ArrayRef<Attribute> operands) {
if (IntegerAttr condAttr = operands.front().dyn_cast_or_null<IntegerAttr>())
return condAttr.getValue().isOneValue() ? getTrueDest() : getFalseDest();
return nullptr;
}
//===----------------------------------------------------------------------===//
// ConstantOp
//===----------------------------------------------------------------------===//
@ -621,439 +242,6 @@ LogicalResult ReturnOp::verify() {
return success();
}
//===----------------------------------------------------------------------===//
// SwitchOp
//===----------------------------------------------------------------------===//
void SwitchOp::build(OpBuilder &builder, OperationState &result, Value value,
Block *defaultDestination, ValueRange defaultOperands,
DenseIntElementsAttr caseValues,
BlockRange caseDestinations,
ArrayRef<ValueRange> caseOperands) {
build(builder, result, value, defaultOperands, caseOperands, caseValues,
defaultDestination, caseDestinations);
}
void SwitchOp::build(OpBuilder &builder, OperationState &result, Value value,
Block *defaultDestination, ValueRange defaultOperands,
ArrayRef<APInt> caseValues, BlockRange caseDestinations,
ArrayRef<ValueRange> caseOperands) {
DenseIntElementsAttr caseValuesAttr;
if (!caseValues.empty()) {
ShapedType caseValueType = VectorType::get(
static_cast<int64_t>(caseValues.size()), value.getType());
caseValuesAttr = DenseIntElementsAttr::get(caseValueType, caseValues);
}
build(builder, result, value, defaultDestination, defaultOperands,
caseValuesAttr, caseDestinations, caseOperands);
}
/// <cases> ::= `default` `:` bb-id (`(` ssa-use-and-type-list `)`)?
/// ( `,` integer `:` bb-id (`(` ssa-use-and-type-list `)`)? )*
static ParseResult parseSwitchOpCases(
OpAsmParser &parser, Type &flagType, Block *&defaultDestination,
SmallVectorImpl<OpAsmParser::OperandType> &defaultOperands,
SmallVectorImpl<Type> &defaultOperandTypes,
DenseIntElementsAttr &caseValues,
SmallVectorImpl<Block *> &caseDestinations,
SmallVectorImpl<SmallVector<OpAsmParser::OperandType>> &caseOperands,
SmallVectorImpl<SmallVector<Type>> &caseOperandTypes) {
if (parser.parseKeyword("default") || parser.parseColon() ||
parser.parseSuccessor(defaultDestination))
return failure();
if (succeeded(parser.parseOptionalLParen())) {
if (parser.parseRegionArgumentList(defaultOperands) ||
parser.parseColonTypeList(defaultOperandTypes) || parser.parseRParen())
return failure();
}
SmallVector<APInt> values;
unsigned bitWidth = flagType.getIntOrFloatBitWidth();
while (succeeded(parser.parseOptionalComma())) {
int64_t value = 0;
if (failed(parser.parseInteger(value)))
return failure();
values.push_back(APInt(bitWidth, value));
Block *destination;
SmallVector<OpAsmParser::OperandType> operands;
SmallVector<Type> operandTypes;
if (failed(parser.parseColon()) ||
failed(parser.parseSuccessor(destination)))
return failure();
if (succeeded(parser.parseOptionalLParen())) {
if (failed(parser.parseRegionArgumentList(operands)) ||
failed(parser.parseColonTypeList(operandTypes)) ||
failed(parser.parseRParen()))
return failure();
}
caseDestinations.push_back(destination);
caseOperands.emplace_back(operands);
caseOperandTypes.emplace_back(operandTypes);
}
if (!values.empty()) {
ShapedType caseValueType =
VectorType::get(static_cast<int64_t>(values.size()), flagType);
caseValues = DenseIntElementsAttr::get(caseValueType, values);
}
return success();
}
static void printSwitchOpCases(
OpAsmPrinter &p, SwitchOp op, Type flagType, Block *defaultDestination,
OperandRange defaultOperands, TypeRange defaultOperandTypes,
DenseIntElementsAttr caseValues, SuccessorRange caseDestinations,
OperandRangeRange caseOperands, const TypeRangeRange &caseOperandTypes) {
p << " default: ";
p.printSuccessorAndUseList(defaultDestination, defaultOperands);
if (!caseValues)
return;
for (const auto &it : llvm::enumerate(caseValues.getValues<APInt>())) {
p << ',';
p.printNewline();
p << " ";
p << it.value().getLimitedValue();
p << ": ";
p.printSuccessorAndUseList(caseDestinations[it.index()],
caseOperands[it.index()]);
}
p.printNewline();
}
LogicalResult SwitchOp::verify() {
auto caseValues = getCaseValues();
auto caseDestinations = getCaseDestinations();
if (!caseValues && caseDestinations.empty())
return success();
Type flagType = getFlag().getType();
Type caseValueType = caseValues->getType().getElementType();
if (caseValueType != flagType)
return emitOpError() << "'flag' type (" << flagType
<< ") should match case value type (" << caseValueType
<< ")";
if (caseValues &&
caseValues->size() != static_cast<int64_t>(caseDestinations.size()))
return emitOpError() << "number of case values (" << caseValues->size()
<< ") should match number of "
"case destinations ("
<< caseDestinations.size() << ")";
return success();
}
Optional<MutableOperandRange>
SwitchOp::getMutableSuccessorOperands(unsigned index) {
assert(index < getNumSuccessors() && "invalid successor index");
return index == 0 ? getDefaultOperandsMutable()
: getCaseOperandsMutable(index - 1);
}
Block *SwitchOp::getSuccessorForOperands(ArrayRef<Attribute> operands) {
Optional<DenseIntElementsAttr> caseValues = getCaseValues();
if (!caseValues)
return getDefaultDestination();
SuccessorRange caseDests = getCaseDestinations();
if (auto value = operands.front().dyn_cast_or_null<IntegerAttr>()) {
for (const auto &it : llvm::enumerate(caseValues->getValues<APInt>()))
if (it.value() == value.getValue())
return caseDests[it.index()];
return getDefaultDestination();
}
return nullptr;
}
/// switch %flag : i32, [
/// default: ^bb1
/// ]
/// -> br ^bb1
static LogicalResult simplifySwitchWithOnlyDefault(SwitchOp op,
PatternRewriter &rewriter) {
if (!op.getCaseDestinations().empty())
return failure();
rewriter.replaceOpWithNewOp<BranchOp>(op, op.getDefaultDestination(),
op.getDefaultOperands());
return success();
}
/// switch %flag : i32, [
///   default: ^bb1,
///   42: ^bb1,
///   43: ^bb2
/// ]
/// ->
/// switch %flag : i32, [
///   default: ^bb1,
///   43: ^bb2
/// ]
static LogicalResult
dropSwitchCasesThatMatchDefault(SwitchOp op, PatternRewriter &rewriter) {
  Block *defaultDest = op.getDefaultDestination();
  auto values = op.getCaseValues();
  auto dests = op.getCaseDestinations();

  SmallVector<Block *> keptDests;
  SmallVector<ValueRange> keptOperands;
  SmallVector<APInt> keptValues;
  bool droppedAny = false;

  unsigned idx = 0;
  for (const APInt &value : values->getValues<APInt>()) {
    // A case that is an exact duplicate of the default edge (same destination
    // and same forwarded operands) is redundant and can be dropped.
    if (dests[idx] == defaultDest &&
        op.getCaseOperands(idx) == op.getDefaultOperands()) {
      droppedAny = true;
    } else {
      keptDests.push_back(dests[idx]);
      keptOperands.push_back(op.getCaseOperands(idx));
      keptValues.push_back(value);
    }
    ++idx;
  }
  if (!droppedAny)
    return failure();

  rewriter.replaceOpWithNewOp<SwitchOp>(op, op.getFlag(), defaultDest,
                                        op.getDefaultOperands(), keptValues,
                                        keptDests, keptOperands);
  return success();
}
/// Helper for folding a switch with a constant value.
/// switch %c_42 : i32, [
///   default: ^bb1 ,
///   42: ^bb2,
///   43: ^bb3
/// ]
/// -> br ^bb2
static void foldSwitch(SwitchOp op, PatternRewriter &rewriter,
                       const APInt &caseValue) {
  // Replace the switch with a branch to the case matching `caseValue`, if
  // any, forwarding that case's operands.
  auto values = op.getCaseValues();
  SuccessorRange dests = op.getCaseDestinations();
  unsigned idx = 0;
  for (const APInt &candidate : values->getValues<APInt>()) {
    if (candidate == caseValue) {
      rewriter.replaceOpWithNewOp<BranchOp>(op, dests[idx],
                                            op.getCaseOperands(idx));
      return;
    }
    ++idx;
  }
  // No case matched: the default edge is taken unconditionally.
  rewriter.replaceOpWithNewOp<BranchOp>(op, op.getDefaultDestination(),
                                        op.getDefaultOperands());
}
/// switch %c_42 : i32, [
///   default: ^bb1,
///   42: ^bb2,
///   43: ^bb3
/// ]
/// -> br ^bb2
static LogicalResult simplifyConstSwitchValue(SwitchOp op,
                                              PatternRewriter &rewriter) {
  // Bail out unless the flag folds to a compile-time integer constant.
  APInt flagValue;
  if (!matchPattern(op.getFlag(), m_ConstantInt(&flagValue)))
    return failure();

  foldSwitch(op, rewriter, flagValue);
  return success();
}
/// switch %c_42 : i32, [
///   default: ^bb1,
///   42: ^bb2,
/// ]
/// ^bb2:
///   br ^bb3
/// ->
/// switch %c_42 : i32, [
///   default: ^bb1,
///   42: ^bb3,
/// ]
static LogicalResult simplifyPassThroughSwitch(SwitchOp op,
                                               PatternRewriter &rewriter) {
  SmallVector<Block *> newCaseDests;
  SmallVector<ValueRange> newCaseOperands;
  // Backing storage for operand vectors produced by collapseBranch; it must
  // outlive the ValueRanges collected in newCaseOperands, which may point
  // into these vectors.
  SmallVector<SmallVector<Value>> argStorage;
  auto caseValues = op.getCaseValues();
  auto caseDests = op.getCaseDestinations();
  bool requiresChange = false;
  // Try to collapse each case destination that merely forwards to another
  // block. NOTE: collapseBranch updates `caseDest` and `caseOperands` in
  // place on success, so the (possibly rewritten) values are pushed
  // unconditionally below.
  for (int64_t i = 0, size = caseValues->size(); i < size; ++i) {
    Block *caseDest = caseDests[i];
    ValueRange caseOperands = op.getCaseOperands(i);
    argStorage.emplace_back();
    if (succeeded(collapseBranch(caseDest, caseOperands, argStorage.back())))
      requiresChange = true;
    newCaseDests.push_back(caseDest);
    newCaseOperands.push_back(caseOperands);
  }
  // The default destination may be a pass-through block as well.
  Block *defaultDest = op.getDefaultDestination();
  ValueRange defaultOperands = op.getDefaultOperands();
  argStorage.emplace_back();
  if (succeeded(
          collapseBranch(defaultDest, defaultOperands, argStorage.back())))
    requiresChange = true;
  if (!requiresChange)
    return failure();
  // Rebuild the switch with the (partially) collapsed destinations; the case
  // values themselves are unchanged.
  rewriter.replaceOpWithNewOp<SwitchOp>(op, op.getFlag(), defaultDest,
                                        defaultOperands, caseValues.getValue(),
                                        newCaseDests, newCaseOperands);
  return success();
}
/// switch %flag : i32, [
///   default: ^bb1,
///   42: ^bb2,
/// ]
/// ^bb2:
///   switch %flag : i32, [
///     default: ^bb3,
///     42: ^bb4
///   ]
/// ->
/// switch %flag : i32, [
///   default: ^bb1,
///   42: ^bb2,
/// ]
/// ^bb2:
///   br ^bb4
///
///  and
///
/// switch %flag : i32, [
///   default: ^bb1,
///   42: ^bb2,
/// ]
/// ^bb2:
///   switch %flag : i32, [
///     default: ^bb3,
///     43: ^bb4
///   ]
/// ->
/// switch %flag : i32, [
///   default: ^bb1,
///   42: ^bb2,
/// ]
/// ^bb2:
///   br ^bb3
static LogicalResult
simplifySwitchFromSwitchOnSameCondition(SwitchOp op,
                                        PatternRewriter &rewriter) {
  // The block containing `op` must have exactly one distinct predecessor.
  Block *block = op->getBlock();
  Block *pred = block->getSinglePredecessor();
  if (!pred)
    return failure();

  // That predecessor must itself terminate with a switch on the same flag,
  // and must reach this block through a non-default edge.
  auto parentSwitch = dyn_cast<SwitchOp>(pred->getTerminator());
  if (!parentSwitch || op.getFlag() != parentSwitch.getFlag() ||
      parentSwitch.getDefaultDestination() == block)
    return failure();

  // If we arrived via a case edge for value V, the flag is known to equal V
  // here, so the switch folds to the branch V selects.
  SuccessorRange parentDests = parentSwitch.getCaseDestinations();
  auto pos = llvm::find(parentDests, block);
  if (pos == parentDests.end()) {
    rewriter.replaceOpWithNewOp<BranchOp>(op, op.getDefaultDestination(),
                                          op.getDefaultOperands());
    return success();
  }
  Optional<DenseIntElementsAttr> parentValues = parentSwitch.getCaseValues();
  foldSwitch(op, rewriter,
             parentValues->getValues<APInt>()[pos - parentDests.begin()]);
  return success();
}
/// switch %flag : i32, [
///   default: ^bb1,
///   42: ^bb2
/// ]
/// ^bb1:
///   switch %flag : i32, [
///     default: ^bb3,
///     42: ^bb4,
///     43: ^bb5
///   ]
/// ->
/// switch %flag : i32, [
///   default: ^bb1,
///   42: ^bb2,
/// ]
/// ^bb1:
///   switch %flag : i32, [
///     default: ^bb3,
///     43: ^bb5
///   ]
static LogicalResult
simplifySwitchFromDefaultSwitchOnSameCondition(SwitchOp op,
                                               PatternRewriter &rewriter) {
  // The block containing `op` must have exactly one distinct predecessor.
  Block *block = op->getBlock();
  Block *pred = block->getSinglePredecessor();
  if (!pred)
    return failure();

  // That predecessor must terminate with a switch on the same flag whose
  // *default* edge leads here.
  auto parentSwitch = dyn_cast<SwitchOp>(pred->getTerminator());
  if (!parentSwitch || op.getFlag() != parentSwitch.getFlag() ||
      parentSwitch.getDefaultDestination() != block)
    return failure();

  // On the default edge the flag is known to differ from every parent case
  // value that dispatches to another block; collect those impossible values.
  DenseSet<APInt> impossibleValues;
  auto parentDests = parentSwitch.getCaseDestinations();
  auto parentValues = parentSwitch.getCaseValues();
  for (int64_t i = 0, e = parentValues->size(); i != e; ++i)
    if (block != parentDests[i])
      impossibleValues.insert(parentValues->getValues<APInt>()[i]);

  // Rebuild this switch keeping only the cases that are still reachable.
  SmallVector<Block *> keptDests;
  SmallVector<ValueRange> keptOperands;
  SmallVector<APInt> keptValues;
  bool changed = false;
  auto values = op.getCaseValues();
  auto dests = op.getCaseDestinations();
  unsigned idx = 0;
  for (const APInt &value : values->getValues<APInt>()) {
    if (impossibleValues.contains(value)) {
      changed = true;
    } else {
      keptDests.push_back(dests[idx]);
      keptOperands.push_back(op.getCaseOperands(idx));
      keptValues.push_back(value);
    }
    ++idx;
  }
  if (!changed)
    return failure();

  rewriter.replaceOpWithNewOp<SwitchOp>(
      op, op.getFlag(), op.getDefaultDestination(), op.getDefaultOperands(),
      keptValues, keptDests, keptOperands);
  return success();
}
/// Registers the switch canonicalization patterns, one standalone rewrite per
/// simplification.
void SwitchOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                           MLIRContext *context) {
  results.add(&simplifySwitchWithOnlyDefault);
  results.add(&dropSwitchCasesThatMatchDefault);
  results.add(&simplifyConstSwitchValue);
  results.add(&simplifyPassThroughSwitch);
  results.add(&simplifySwitchFromSwitchOnSameCondition);
  results.add(&simplifySwitchFromDefaultSwitchOnSameCondition);
}
//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//

View File

@ -7,6 +7,7 @@
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
#include "mlir/Dialect/EmitC/IR/EmitC.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/SCF/SCF.h"
@ -41,6 +42,7 @@ void registerToCppTranslation() {
[](DialectRegistry &registry) {
// clang-format off
registry.insert<arith::ArithmeticDialect,
cf::ControlFlowDialect,
emitc::EmitCDialect,
math::MathDialect,
StandardOpsDialect,

View File

@ -6,8 +6,7 @@
//
//===----------------------------------------------------------------------===//
#include <utility>
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/EmitC/IR/EmitC.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
@ -23,6 +22,7 @@
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/FormatVariadic.h"
#include <utility>
#define DEBUG_TYPE "translate-to-cpp"
@ -237,7 +237,8 @@ static LogicalResult printOperation(CppEmitter &emitter,
return printConstantOp(emitter, operation, value);
}
static LogicalResult printOperation(CppEmitter &emitter, BranchOp branchOp) {
static LogicalResult printOperation(CppEmitter &emitter,
cf::BranchOp branchOp) {
raw_ostream &os = emitter.ostream();
Block &successor = *branchOp.getSuccessor();
@ -257,7 +258,7 @@ static LogicalResult printOperation(CppEmitter &emitter, BranchOp branchOp) {
}
static LogicalResult printOperation(CppEmitter &emitter,
CondBranchOp condBranchOp) {
cf::CondBranchOp condBranchOp) {
raw_indented_ostream &os = emitter.ostream();
Block &trueSuccessor = *condBranchOp.getTrueDest();
Block &falseSuccessor = *condBranchOp.getFalseDest();
@ -637,11 +638,12 @@ static LogicalResult printOperation(CppEmitter &emitter, FuncOp functionOp) {
return failure();
}
for (Operation &op : block.getOperations()) {
// When generating code for an scf.if or std.cond_br op no semicolon needs
// When generating code for an scf.if or cf.cond_br op no semicolon needs
// to be printed after the closing brace.
// When generating code for an scf.for op, printing a trailing semicolon
// is handled within the printOperation function.
bool trailingSemicolon = !isa<scf::IfOp, scf::ForOp, CondBranchOp>(op);
bool trailingSemicolon =
!isa<scf::IfOp, scf::ForOp, cf::CondBranchOp>(op);
if (failed(emitter.emitOperation(
op, /*trailingSemicolon=*/trailingSemicolon)))
@ -907,8 +909,8 @@ LogicalResult CppEmitter::emitOperation(Operation &op, bool trailingSemicolon) {
.Case<scf::ForOp, scf::IfOp, scf::YieldOp>(
[&](auto op) { return printOperation(*this, op); })
// Standard ops.
.Case<BranchOp, mlir::CallOp, CondBranchOp, mlir::ConstantOp, FuncOp,
ModuleOp, ReturnOp>(
.Case<cf::BranchOp, mlir::CallOp, cf::CondBranchOp, mlir::ConstantOp,
FuncOp, ModuleOp, ReturnOp>(
[&](auto op) { return printOperation(*this, op); })
// Arithmetic ops.
.Case<arith::ConstantOp>(

View File

@ -52,10 +52,10 @@ func @control_flow(%arg: memref<2xf32>, %cond: i1) attributes {test.ptr = "func"
%1 = memref.alloca() {test.ptr = "alloca_2"} : memref<8x64xf32>
%2 = memref.alloc() {test.ptr = "alloc_1"} : memref<8x64xf32>
cond_br %cond, ^bb1(%0 : memref<8x64xf32>), ^bb2(%0 : memref<8x64xf32>)
cf.cond_br %cond, ^bb1(%0 : memref<8x64xf32>), ^bb2(%0 : memref<8x64xf32>)
^bb1(%arg1: memref<8x64xf32>):
br ^bb2(%arg1 : memref<8x64xf32>)
cf.br ^bb2(%arg1 : memref<8x64xf32>)
^bb2(%arg2: memref<8x64xf32>):
return
@ -85,10 +85,10 @@ func @control_flow_merge(%arg: memref<2xf32>, %cond: i1) attributes {test.ptr =
%1 = memref.alloca() {test.ptr = "alloca_2"} : memref<8x64xf32>
%2 = memref.alloc() {test.ptr = "alloc_1"} : memref<8x64xf32>
cond_br %cond, ^bb1(%0 : memref<8x64xf32>), ^bb2(%2 : memref<8x64xf32>)
cf.cond_br %cond, ^bb1(%0 : memref<8x64xf32>), ^bb2(%2 : memref<8x64xf32>)
^bb1(%arg1: memref<8x64xf32>):
br ^bb2(%arg1 : memref<8x64xf32>)
cf.br ^bb2(%arg1 : memref<8x64xf32>)
^bb2(%arg2: memref<8x64xf32>):
return

View File

@ -2,11 +2,11 @@
// CHECK-LABEL: Testing : func_condBranch
func @func_condBranch(%cond : i1) {
cond_br %cond, ^bb1, ^bb2
cf.cond_br %cond, ^bb1, ^bb2
^bb1:
br ^exit
cf.br ^exit
^bb2:
br ^exit
cf.br ^exit
^exit:
return
}
@ -49,14 +49,14 @@ func @func_condBranch(%cond : i1) {
// CHECK-LABEL: Testing : func_loop
func @func_loop(%arg0 : i32, %arg1 : i32) {
br ^loopHeader(%arg0 : i32)
cf.br ^loopHeader(%arg0 : i32)
^loopHeader(%counter : i32):
%lessThan = arith.cmpi slt, %counter, %arg1 : i32
cond_br %lessThan, ^loopBody, ^exit
cf.cond_br %lessThan, ^loopBody, ^exit
^loopBody:
%const0 = arith.constant 1 : i32
%inc = arith.addi %counter, %const0 : i32
br ^loopHeader(%inc : i32)
cf.br ^loopHeader(%inc : i32)
^exit:
return
}
@ -153,17 +153,17 @@ func @func_loop_nested_region(
%arg2 : index,
%arg3 : index,
%arg4 : index) {
br ^loopHeader(%arg0 : i32)
cf.br ^loopHeader(%arg0 : i32)
^loopHeader(%counter : i32):
%lessThan = arith.cmpi slt, %counter, %arg1 : i32
cond_br %lessThan, ^loopBody, ^exit
cf.cond_br %lessThan, ^loopBody, ^exit
^loopBody:
%const0 = arith.constant 1 : i32
%inc = arith.addi %counter, %const0 : i32
scf.for %arg5 = %arg2 to %arg3 step %arg4 {
scf.for %arg6 = %arg2 to %arg3 step %arg4 { }
}
br ^loopHeader(%inc : i32)
cf.br ^loopHeader(%inc : i32)
^exit:
return
}

View File

@ -19,7 +19,7 @@ func @func_simpleBranch(%arg0: i32, %arg1 : i32) -> i32 {
// CHECK-NEXT: LiveOut: arg0@0 arg1@0
// CHECK-NEXT: BeginLiveness
// CHECK-NEXT: EndLiveness
br ^exit
cf.br ^exit
^exit:
// CHECK: Block: 1
// CHECK-NEXT: LiveIn: arg0@0 arg1@0
@ -42,17 +42,17 @@ func @func_condBranch(%cond : i1, %arg1: i32, %arg2 : i32) -> i32 {
// CHECK-NEXT: LiveOut: arg1@0 arg2@0
// CHECK-NEXT: BeginLiveness
// CHECK-NEXT: EndLiveness
cond_br %cond, ^bb1, ^bb2
cf.cond_br %cond, ^bb1, ^bb2
^bb1:
// CHECK: Block: 1
// CHECK-NEXT: LiveIn: arg1@0 arg2@0
// CHECK-NEXT: LiveOut: arg1@0 arg2@0
br ^exit
cf.br ^exit
^bb2:
// CHECK: Block: 2
// CHECK-NEXT: LiveIn: arg1@0 arg2@0
// CHECK-NEXT: LiveOut: arg1@0 arg2@0
br ^exit
cf.br ^exit
^exit:
// CHECK: Block: 3
// CHECK-NEXT: LiveIn: arg1@0 arg2@0
@ -74,7 +74,7 @@ func @func_loop(%arg0 : i32, %arg1 : i32) -> i32 {
// CHECK-NEXT: LiveIn:{{ *$}}
// CHECK-NEXT: LiveOut: arg1@0
%const0 = arith.constant 0 : i32
br ^loopHeader(%const0, %arg0 : i32, i32)
cf.br ^loopHeader(%const0, %arg0 : i32, i32)
^loopHeader(%counter : i32, %i : i32):
// CHECK: Block: 1
// CHECK-NEXT: LiveIn: arg1@0
@ -82,10 +82,10 @@ func @func_loop(%arg0 : i32, %arg1 : i32) -> i32 {
// CHECK-NEXT: BeginLiveness
// CHECK-NEXT: val_5
// CHECK-NEXT: %2 = arith.cmpi
// CHECK-NEXT: cond_br
// CHECK-NEXT: cf.cond_br
// CHECK-NEXT: EndLiveness
%lessThan = arith.cmpi slt, %counter, %arg1 : i32
cond_br %lessThan, ^loopBody(%i : i32), ^exit(%i : i32)
cf.cond_br %lessThan, ^loopBody(%i : i32), ^exit(%i : i32)
^loopBody(%val : i32):
// CHECK: Block: 2
// CHECK-NEXT: LiveIn: arg1@0 arg0@1
@ -98,12 +98,12 @@ func @func_loop(%arg0 : i32, %arg1 : i32) -> i32 {
// CHECK-NEXT: val_8
// CHECK-NEXT: %4 = arith.addi
// CHECK-NEXT: %5 = arith.addi
// CHECK-NEXT: br
// CHECK-NEXT: cf.br
// CHECK: EndLiveness
%const1 = arith.constant 1 : i32
%inc = arith.addi %val, %const1 : i32
%inc2 = arith.addi %counter, %const1 : i32
br ^loopHeader(%inc, %inc2 : i32, i32)
cf.br ^loopHeader(%inc, %inc2 : i32, i32)
^exit(%sum : i32):
// CHECK: Block: 3
// CHECK-NEXT: LiveIn: arg1@0
@ -147,14 +147,14 @@ func @func_ranges(%cond : i1, %arg1 : i32, %arg2 : i32, %arg3 : i32) -> i32 {
// CHECK-NEXT: val_9
// CHECK-NEXT: %4 = arith.muli
// CHECK-NEXT: %5 = arith.addi
// CHECK-NEXT: cond_br
// CHECK-NEXT: cf.cond_br
// CHECK-NEXT: %c
// CHECK-NEXT: %6 = arith.muli
// CHECK-NEXT: %7 = arith.muli
// CHECK-NEXT: %8 = arith.addi
// CHECK-NEXT: val_10
// CHECK-NEXT: %5 = arith.addi
// CHECK-NEXT: cond_br
// CHECK-NEXT: cf.cond_br
// CHECK-NEXT: %7
// CHECK: EndLiveness
%0 = arith.addi %arg1, %arg2 : i32
@ -164,7 +164,7 @@ func @func_ranges(%cond : i1, %arg1 : i32, %arg2 : i32, %arg3 : i32) -> i32 {
%3 = arith.muli %0, %1 : i32
%4 = arith.muli %3, %2 : i32
%5 = arith.addi %4, %const1 : i32
cond_br %cond, ^bb1, ^bb2
cf.cond_br %cond, ^bb1, ^bb2
^bb1:
// CHECK: Block: 1
@ -172,7 +172,7 @@ func @func_ranges(%cond : i1, %arg1 : i32, %arg2 : i32, %arg3 : i32) -> i32 {
// CHECK-NEXT: LiveOut: arg2@0
%const4 = arith.constant 4 : i32
%6 = arith.muli %4, %const4 : i32
br ^exit(%6 : i32)
cf.br ^exit(%6 : i32)
^bb2:
// CHECK: Block: 2
@ -180,7 +180,7 @@ func @func_ranges(%cond : i1, %arg1 : i32, %arg2 : i32, %arg3 : i32) -> i32 {
// CHECK-NEXT: LiveOut: arg2@0
%7 = arith.muli %4, %5 : i32
%8 = arith.addi %4, %arg2 : i32
br ^exit(%8 : i32)
cf.br ^exit(%8 : i32)
^exit(%sum : i32):
// CHECK: Block: 3
@ -284,7 +284,7 @@ func @nested_region3(
// CHECK-NEXT: %0 = arith.addi
// CHECK-NEXT: %1 = arith.addi
// CHECK-NEXT: scf.for
// CHECK: // br ^bb1
// CHECK: // cf.br ^bb1
// CHECK-NEXT: %2 = arith.addi
// CHECK-NEXT: scf.for
// CHECK: // %2 = arith.addi
@ -301,7 +301,7 @@ func @nested_region3(
%2 = arith.addi %0, %arg5 : i32
memref.store %2, %buffer[] : memref<i32>
}
br ^exit
cf.br ^exit
^exit:
// CHECK: Block: 2

View File

@ -1531,10 +1531,10 @@ int registerOnlyStd() {
fprintf(stderr, "@registration\n");
// CHECK-LABEL: @registration
// CHECK: std.cond_br is_registered: 1
fprintf(stderr, "std.cond_br is_registered: %d\n",
// CHECK: cf.cond_br is_registered: 1
fprintf(stderr, "cf.cond_br is_registered: %d\n",
mlirContextIsRegisteredOperation(
ctx, mlirStringRefCreateFromCString("std.cond_br")));
ctx, mlirStringRefCreateFromCString("cf.cond_br")));
// CHECK: std.not_existing_op is_registered: 0
fprintf(stderr, "std.not_existing_op is_registered: %d\n",

View File

@ -27,7 +27,7 @@ func @execute_no_async_args(%arg0: f32, %arg1: memref<1xf32>) {
// CHECK: %[[IS_ERROR:.*]] = call @mlirAsyncRuntimeIsTokenError(%[[TOKEN]])
// CHECK: %[[TRUE:.*]] = arith.constant true
// CHECK: %[[NOT_ERROR:.*]] = arith.xori %[[IS_ERROR]], %[[TRUE]] : i1
// CHECK: assert %[[NOT_ERROR]]
// CHECK: cf.assert %[[NOT_ERROR]]
// CHECK-NEXT: return
async.await %token : !async.token
return
@ -90,7 +90,7 @@ func @nested_async_execute(%arg0: f32, %arg1: f32, %arg2: memref<1xf32>) {
// CHECK: %[[IS_ERROR:.*]] = call @mlirAsyncRuntimeIsTokenError(%[[TOKEN]])
// CHECK: %[[TRUE:.*]] = arith.constant true
// CHECK: %[[NOT_ERROR:.*]] = arith.xori %[[IS_ERROR]], %[[TRUE]] : i1
// CHECK: assert %[[NOT_ERROR]]
// CHECK: cf.assert %[[NOT_ERROR]]
async.await %token0 : !async.token
return
}

View File

@ -0,0 +1,41 @@
// RUN: mlir-opt -split-input-file -convert-std-to-spirv -verify-diagnostics %s | FileCheck %s
//===----------------------------------------------------------------------===//
// cf.br, cf.cond_br
//===----------------------------------------------------------------------===//
module attributes {
spv.target_env = #spv.target_env<#spv.vce<v1.0, [], []>, {}>
} {
// CHECK-LABEL: func @simple_loop
func @simple_loop(index, index, index) {
^bb0(%begin : index, %end : index, %step : index):
// CHECK-NEXT: spv.Branch ^bb1
cf.br ^bb1
// CHECK-NEXT: ^bb1: // pred: ^bb0
// CHECK-NEXT: spv.Branch ^bb2({{.*}} : i32)
^bb1: // pred: ^bb0
cf.br ^bb2(%begin : index)
// CHECK: ^bb2({{.*}}: i32): // 2 preds: ^bb1, ^bb3
// CHECK-NEXT: {{.*}} = spv.SLessThan {{.*}}, {{.*}} : i32
// CHECK-NEXT: spv.BranchConditional {{.*}}, ^bb3, ^bb4
^bb2(%0: index): // 2 preds: ^bb1, ^bb3
%1 = arith.cmpi slt, %0, %end : index
cf.cond_br %1, ^bb3, ^bb4
// CHECK: ^bb3: // pred: ^bb2
// CHECK-NEXT: {{.*}} = spv.IAdd {{.*}}, {{.*}} : i32
// CHECK-NEXT: spv.Branch ^bb2({{.*}} : i32)
^bb3: // pred: ^bb2
%2 = arith.addi %0, %step : index
cf.br ^bb2(%2 : index)
// CHECK: ^bb4: // pred: ^bb2
^bb4: // pred: ^bb2
return
}
}

View File

@ -168,16 +168,16 @@ gpu.module @test_module {
%c128 = arith.constant 128 : index
%c32 = arith.constant 32 : index
%0 = gpu.subgroup_mma_load_matrix %arg2[%c0, %c0] {leadDimension = 128 : index} : memref<128x128xf16> -> !gpu.mma_matrix<16x16xf16, "COp">
br ^bb1(%c0, %0 : index, !gpu.mma_matrix<16x16xf16, "COp">)
cf.br ^bb1(%c0, %0 : index, !gpu.mma_matrix<16x16xf16, "COp">)
^bb1(%1: index, %2: !gpu.mma_matrix<16x16xf16, "COp">): // 2 preds: ^bb0, ^bb2
%3 = arith.cmpi slt, %1, %c128 : index
cond_br %3, ^bb2, ^bb3
cf.cond_br %3, ^bb2, ^bb3
^bb2: // pred: ^bb1
%4 = gpu.subgroup_mma_load_matrix %arg0[%c0, %1] {leadDimension = 128 : index} : memref<128x128xf16> -> !gpu.mma_matrix<16x16xf16, "AOp">
%5 = gpu.subgroup_mma_load_matrix %arg1[%1, %c0] {leadDimension = 128 : index} : memref<128x128xf16> -> !gpu.mma_matrix<16x16xf16, "BOp">
%6 = gpu.subgroup_mma_compute %4, %5, %2 : !gpu.mma_matrix<16x16xf16, "AOp">, !gpu.mma_matrix<16x16xf16, "BOp"> -> !gpu.mma_matrix<16x16xf16, "COp">
%7 = arith.addi %1, %c32 : index
br ^bb1(%7, %6 : index, !gpu.mma_matrix<16x16xf16, "COp">)
cf.br ^bb1(%7, %6 : index, !gpu.mma_matrix<16x16xf16, "COp">)
^bb3: // pred: ^bb1
gpu.subgroup_mma_store_matrix %2, %arg2[%c0, %c0] {leadDimension = 128 : index} : !gpu.mma_matrix<16x16xf16, "COp">, memref<128x128xf16>
return

View File

@ -22,17 +22,17 @@ func @branch_loop() {
// CHECK: omp.parallel
omp.parallel {
// CHECK-NEXT: llvm.br ^[[BB1:.*]](%{{[0-9]+}}, %{{[0-9]+}} : i64, i64
br ^bb1(%start, %end : index, index)
cf.br ^bb1(%start, %end : index, index)
// CHECK-NEXT: ^[[BB1]](%[[ARG1:[0-9]+]]: i64, %[[ARG2:[0-9]+]]: i64):{{.*}}
^bb1(%0: index, %1: index):
// CHECK-NEXT: %[[CMP:[0-9]+]] = llvm.icmp "slt" %[[ARG1]], %[[ARG2]] : i64
%2 = arith.cmpi slt, %0, %1 : index
// CHECK-NEXT: llvm.cond_br %[[CMP]], ^[[BB2:.*]](%{{[0-9]+}}, %{{[0-9]+}} : i64, i64), ^[[BB3:.*]]
cond_br %2, ^bb2(%end, %end : index, index), ^bb3
cf.cond_br %2, ^bb2(%end, %end : index, index), ^bb3
// CHECK-NEXT: ^[[BB2]](%[[ARG3:[0-9]+]]: i64, %[[ARG4:[0-9]+]]: i64):
^bb2(%3: index, %4: index):
// CHECK-NEXT: llvm.br ^[[BB1]](%[[ARG3]], %[[ARG4]] : i64, i64)
br ^bb1(%3, %4 : index, index)
cf.br ^bb1(%3, %4 : index, index)
// CHECK-NEXT: ^[[BB3]]:
^bb3:
omp.flush

View File

@ -1,14 +1,14 @@
// RUN: mlir-opt -allow-unregistered-dialect -convert-scf-to-std %s | FileCheck %s
// RUN: mlir-opt -allow-unregistered-dialect -convert-scf-to-cf %s | FileCheck %s
// CHECK-LABEL: func @simple_std_for_loop(%{{.*}}: index, %{{.*}}: index, %{{.*}}: index) {
// CHECK-NEXT: br ^bb1(%{{.*}} : index)
// CHECK-NEXT: cf.br ^bb1(%{{.*}} : index)
// CHECK-NEXT: ^bb1(%{{.*}}: index): // 2 preds: ^bb0, ^bb2
// CHECK-NEXT: %{{.*}} = arith.cmpi slt, %{{.*}}, %{{.*}} : index
// CHECK-NEXT: cond_br %{{.*}}, ^bb2, ^bb3
// CHECK-NEXT: cf.cond_br %{{.*}}, ^bb2, ^bb3
// CHECK-NEXT: ^bb2: // pred: ^bb1
// CHECK-NEXT: %{{.*}} = arith.constant 1 : index
// CHECK-NEXT: %[[iv:.*]] = arith.addi %{{.*}}, %{{.*}} : index
// CHECK-NEXT: br ^bb1(%[[iv]] : index)
// CHECK-NEXT: cf.br ^bb1(%[[iv]] : index)
// CHECK-NEXT: ^bb3: // pred: ^bb1
// CHECK-NEXT: return
func @simple_std_for_loop(%arg0 : index, %arg1 : index, %arg2 : index) {
@ -19,23 +19,23 @@ func @simple_std_for_loop(%arg0 : index, %arg1 : index, %arg2 : index) {
}
// CHECK-LABEL: func @simple_std_2_for_loops(%{{.*}}: index, %{{.*}}: index, %{{.*}}: index) {
// CHECK-NEXT: br ^bb1(%{{.*}} : index)
// CHECK-NEXT: cf.br ^bb1(%{{.*}} : index)
// CHECK-NEXT: ^bb1(%[[ub0:.*]]: index): // 2 preds: ^bb0, ^bb5
// CHECK-NEXT: %[[cond0:.*]] = arith.cmpi slt, %[[ub0]], %{{.*}} : index
// CHECK-NEXT: cond_br %[[cond0]], ^bb2, ^bb6
// CHECK-NEXT: cf.cond_br %[[cond0]], ^bb2, ^bb6
// CHECK-NEXT: ^bb2: // pred: ^bb1
// CHECK-NEXT: %{{.*}} = arith.constant 1 : index
// CHECK-NEXT: br ^bb3(%{{.*}} : index)
// CHECK-NEXT: cf.br ^bb3(%{{.*}} : index)
// CHECK-NEXT: ^bb3(%[[ub1:.*]]: index): // 2 preds: ^bb2, ^bb4
// CHECK-NEXT: %[[cond1:.*]] = arith.cmpi slt, %{{.*}}, %{{.*}} : index
// CHECK-NEXT: cond_br %[[cond1]], ^bb4, ^bb5
// CHECK-NEXT: cf.cond_br %[[cond1]], ^bb4, ^bb5
// CHECK-NEXT: ^bb4: // pred: ^bb3
// CHECK-NEXT: %{{.*}} = arith.constant 1 : index
// CHECK-NEXT: %[[iv1:.*]] = arith.addi %{{.*}}, %{{.*}} : index
// CHECK-NEXT: br ^bb3(%[[iv1]] : index)
// CHECK-NEXT: cf.br ^bb3(%[[iv1]] : index)
// CHECK-NEXT: ^bb5: // pred: ^bb3
// CHECK-NEXT: %[[iv0:.*]] = arith.addi %{{.*}}, %{{.*}} : index
// CHECK-NEXT: br ^bb1(%[[iv0]] : index)
// CHECK-NEXT: cf.br ^bb1(%[[iv0]] : index)
// CHECK-NEXT: ^bb6: // pred: ^bb1
// CHECK-NEXT: return
func @simple_std_2_for_loops(%arg0 : index, %arg1 : index, %arg2 : index) {
@ -49,10 +49,10 @@ func @simple_std_2_for_loops(%arg0 : index, %arg1 : index, %arg2 : index) {
}
// CHECK-LABEL: func @simple_std_if(%{{.*}}: i1) {
// CHECK-NEXT: cond_br %{{.*}}, ^bb1, ^bb2
// CHECK-NEXT: cf.cond_br %{{.*}}, ^bb1, ^bb2
// CHECK-NEXT: ^bb1: // pred: ^bb0
// CHECK-NEXT: %{{.*}} = arith.constant 1 : index
// CHECK-NEXT: br ^bb2
// CHECK-NEXT: cf.br ^bb2
// CHECK-NEXT: ^bb2: // 2 preds: ^bb0, ^bb1
// CHECK-NEXT: return
func @simple_std_if(%arg0: i1) {
@ -63,13 +63,13 @@ func @simple_std_if(%arg0: i1) {
}
// CHECK-LABEL: func @simple_std_if_else(%{{.*}}: i1) {
// CHECK-NEXT: cond_br %{{.*}}, ^bb1, ^bb2
// CHECK-NEXT: cf.cond_br %{{.*}}, ^bb1, ^bb2
// CHECK-NEXT: ^bb1: // pred: ^bb0
// CHECK-NEXT: %{{.*}} = arith.constant 1 : index
// CHECK-NEXT: br ^bb3
// CHECK-NEXT: cf.br ^bb3
// CHECK-NEXT: ^bb2: // pred: ^bb0
// CHECK-NEXT: %{{.*}} = arith.constant 1 : index
// CHECK-NEXT: br ^bb3
// CHECK-NEXT: cf.br ^bb3
// CHECK-NEXT: ^bb3: // 2 preds: ^bb1, ^bb2
// CHECK-NEXT: return
func @simple_std_if_else(%arg0: i1) {
@ -82,18 +82,18 @@ func @simple_std_if_else(%arg0: i1) {
}
// CHECK-LABEL: func @simple_std_2_ifs(%{{.*}}: i1) {
// CHECK-NEXT: cond_br %{{.*}}, ^bb1, ^bb5
// CHECK-NEXT: cf.cond_br %{{.*}}, ^bb1, ^bb5
// CHECK-NEXT: ^bb1: // pred: ^bb0
// CHECK-NEXT: %{{.*}} = arith.constant 1 : index
// CHECK-NEXT: cond_br %{{.*}}, ^bb2, ^bb3
// CHECK-NEXT: cf.cond_br %{{.*}}, ^bb2, ^bb3
// CHECK-NEXT: ^bb2: // pred: ^bb1
// CHECK-NEXT: %{{.*}} = arith.constant 1 : index
// CHECK-NEXT: br ^bb4
// CHECK-NEXT: cf.br ^bb4
// CHECK-NEXT: ^bb3: // pred: ^bb1
// CHECK-NEXT: %{{.*}} = arith.constant 1 : index
// CHECK-NEXT: br ^bb4
// CHECK-NEXT: cf.br ^bb4
// CHECK-NEXT: ^bb4: // 2 preds: ^bb2, ^bb3
// CHECK-NEXT: br ^bb5
// CHECK-NEXT: cf.br ^bb5
// CHECK-NEXT: ^bb5: // 2 preds: ^bb0, ^bb4
// CHECK-NEXT: return
func @simple_std_2_ifs(%arg0: i1) {
@ -109,27 +109,27 @@ func @simple_std_2_ifs(%arg0: i1) {
}
// CHECK-LABEL: func @simple_std_for_loop_with_2_ifs(%{{.*}}: index, %{{.*}}: index, %{{.*}}: index, %{{.*}}: i1) {
// CHECK-NEXT: br ^bb1(%{{.*}} : index)
// CHECK-NEXT: cf.br ^bb1(%{{.*}} : index)
// CHECK-NEXT: ^bb1(%{{.*}}: index): // 2 preds: ^bb0, ^bb7
// CHECK-NEXT: %{{.*}} = arith.cmpi slt, %{{.*}}, %{{.*}} : index
// CHECK-NEXT: cond_br %{{.*}}, ^bb2, ^bb8
// CHECK-NEXT: cf.cond_br %{{.*}}, ^bb2, ^bb8
// CHECK-NEXT: ^bb2: // pred: ^bb1
// CHECK-NEXT: %{{.*}} = arith.constant 1 : index
// CHECK-NEXT: cond_br %{{.*}}, ^bb3, ^bb7
// CHECK-NEXT: cf.cond_br %{{.*}}, ^bb3, ^bb7
// CHECK-NEXT: ^bb3: // pred: ^bb2
// CHECK-NEXT: %{{.*}} = arith.constant 1 : index
// CHECK-NEXT: cond_br %{{.*}}, ^bb4, ^bb5
// CHECK-NEXT: cf.cond_br %{{.*}}, ^bb4, ^bb5
// CHECK-NEXT: ^bb4: // pred: ^bb3
// CHECK-NEXT: %{{.*}} = arith.constant 1 : index
// CHECK-NEXT: br ^bb6
// CHECK-NEXT: cf.br ^bb6
// CHECK-NEXT: ^bb5: // pred: ^bb3
// CHECK-NEXT: %{{.*}} = arith.constant 1 : index
// CHECK-NEXT: br ^bb6
// CHECK-NEXT: cf.br ^bb6
// CHECK-NEXT: ^bb6: // 2 preds: ^bb4, ^bb5
// CHECK-NEXT: br ^bb7
// CHECK-NEXT: cf.br ^bb7
// CHECK-NEXT: ^bb7: // 2 preds: ^bb2, ^bb6
// CHECK-NEXT: %[[iv0:.*]] = arith.addi %{{.*}}, %{{.*}} : index
// CHECK-NEXT: br ^bb1(%[[iv0]] : index)
// CHECK-NEXT: cf.br ^bb1(%[[iv0]] : index)
// CHECK-NEXT: ^bb8: // pred: ^bb1
// CHECK-NEXT: return
// CHECK-NEXT: }
@ -150,12 +150,12 @@ func @simple_std_for_loop_with_2_ifs(%arg0 : index, %arg1 : index, %arg2 : index
// CHECK-LABEL: func @simple_if_yield
func @simple_if_yield(%arg0: i1) -> (i1, i1) {
// CHECK: cond_br %{{.*}}, ^[[then:.*]], ^[[else:.*]]
// CHECK: cf.cond_br %{{.*}}, ^[[then:.*]], ^[[else:.*]]
%0:2 = scf.if %arg0 -> (i1, i1) {
// CHECK: ^[[then]]:
// CHECK: %[[v0:.*]] = arith.constant false
// CHECK: %[[v1:.*]] = arith.constant true
// CHECK: br ^[[dom:.*]](%[[v0]], %[[v1]] : i1, i1)
// CHECK: cf.br ^[[dom:.*]](%[[v0]], %[[v1]] : i1, i1)
%c0 = arith.constant false
%c1 = arith.constant true
scf.yield %c0, %c1 : i1, i1
@ -163,13 +163,13 @@ func @simple_if_yield(%arg0: i1) -> (i1, i1) {
// CHECK: ^[[else]]:
// CHECK: %[[v2:.*]] = arith.constant false
// CHECK: %[[v3:.*]] = arith.constant true
// CHECK: br ^[[dom]](%[[v3]], %[[v2]] : i1, i1)
// CHECK: cf.br ^[[dom]](%[[v3]], %[[v2]] : i1, i1)
%c0 = arith.constant false
%c1 = arith.constant true
scf.yield %c1, %c0 : i1, i1
}
// CHECK: ^[[dom]](%[[arg1:.*]]: i1, %[[arg2:.*]]: i1):
// CHECK: br ^[[cont:.*]]
// CHECK: cf.br ^[[cont:.*]]
// CHECK: ^[[cont]]:
// CHECK: return %[[arg1]], %[[arg2]]
return %0#0, %0#1 : i1, i1
@ -177,49 +177,49 @@ func @simple_if_yield(%arg0: i1) -> (i1, i1) {
// CHECK-LABEL: func @nested_if_yield
func @nested_if_yield(%arg0: i1) -> (index) {
// CHECK: cond_br %{{.*}}, ^[[first_then:.*]], ^[[first_else:.*]]
// CHECK: cf.cond_br %{{.*}}, ^[[first_then:.*]], ^[[first_else:.*]]
%0 = scf.if %arg0 -> i1 {
// CHECK: ^[[first_then]]:
%1 = arith.constant true
// CHECK: br ^[[first_dom:.*]]({{.*}})
// CHECK: cf.br ^[[first_dom:.*]]({{.*}})
scf.yield %1 : i1
} else {
// CHECK: ^[[first_else]]:
%2 = arith.constant false
// CHECK: br ^[[first_dom]]({{.*}})
// CHECK: cf.br ^[[first_dom]]({{.*}})
scf.yield %2 : i1
}
// CHECK: ^[[first_dom]](%[[arg1:.*]]: i1):
// CHECK: br ^[[first_cont:.*]]
// CHECK: cf.br ^[[first_cont:.*]]
// CHECK: ^[[first_cont]]:
// CHECK: cond_br %[[arg1]], ^[[second_outer_then:.*]], ^[[second_outer_else:.*]]
// CHECK: cf.cond_br %[[arg1]], ^[[second_outer_then:.*]], ^[[second_outer_else:.*]]
%1 = scf.if %0 -> index {
// CHECK: ^[[second_outer_then]]:
// CHECK: cond_br %arg0, ^[[second_inner_then:.*]], ^[[second_inner_else:.*]]
// CHECK: cf.cond_br %arg0, ^[[second_inner_then:.*]], ^[[second_inner_else:.*]]
%3 = scf.if %arg0 -> index {
// CHECK: ^[[second_inner_then]]:
%4 = arith.constant 40 : index
// CHECK: br ^[[second_inner_dom:.*]]({{.*}})
// CHECK: cf.br ^[[second_inner_dom:.*]]({{.*}})
scf.yield %4 : index
} else {
// CHECK: ^[[second_inner_else]]:
%5 = arith.constant 41 : index
// CHECK: br ^[[second_inner_dom]]({{.*}})
// CHECK: cf.br ^[[second_inner_dom]]({{.*}})
scf.yield %5 : index
}
// CHECK: ^[[second_inner_dom]](%[[arg2:.*]]: index):
// CHECK: br ^[[second_inner_cont:.*]]
// CHECK: cf.br ^[[second_inner_cont:.*]]
// CHECK: ^[[second_inner_cont]]:
// CHECK: br ^[[second_outer_dom:.*]]({{.*}})
// CHECK: cf.br ^[[second_outer_dom:.*]]({{.*}})
scf.yield %3 : index
} else {
// CHECK: ^[[second_outer_else]]:
%6 = arith.constant 42 : index
// CHECK: br ^[[second_outer_dom]]({{.*}}
// CHECK: cf.br ^[[second_outer_dom]]({{.*}}
scf.yield %6 : index
}
// CHECK: ^[[second_outer_dom]](%[[arg3:.*]]: index):
// CHECK: br ^[[second_outer_cont:.*]]
// CHECK: cf.br ^[[second_outer_cont:.*]]
// CHECK: ^[[second_outer_cont]]:
// CHECK: return %[[arg3]] : index
return %1 : index
@ -228,22 +228,22 @@ func @nested_if_yield(%arg0: i1) -> (index) {
// CHECK-LABEL: func @parallel_loop(
// CHECK-SAME: [[VAL_0:%.*]]: index, [[VAL_1:%.*]]: index, [[VAL_2:%.*]]: index, [[VAL_3:%.*]]: index, [[VAL_4:%.*]]: index) {
// CHECK: [[VAL_5:%.*]] = arith.constant 1 : index
// CHECK: br ^bb1([[VAL_0]] : index)
// CHECK: cf.br ^bb1([[VAL_0]] : index)
// CHECK: ^bb1([[VAL_6:%.*]]: index):
// CHECK: [[VAL_7:%.*]] = arith.cmpi slt, [[VAL_6]], [[VAL_2]] : index
// CHECK: cond_br [[VAL_7]], ^bb2, ^bb6
// CHECK: cf.cond_br [[VAL_7]], ^bb2, ^bb6
// CHECK: ^bb2:
// CHECK: br ^bb3([[VAL_1]] : index)
// CHECK: cf.br ^bb3([[VAL_1]] : index)
// CHECK: ^bb3([[VAL_8:%.*]]: index):
// CHECK: [[VAL_9:%.*]] = arith.cmpi slt, [[VAL_8]], [[VAL_3]] : index
// CHECK: cond_br [[VAL_9]], ^bb4, ^bb5
// CHECK: cf.cond_br [[VAL_9]], ^bb4, ^bb5
// CHECK: ^bb4:
// CHECK: [[VAL_10:%.*]] = arith.constant 1 : index
// CHECK: [[VAL_11:%.*]] = arith.addi [[VAL_8]], [[VAL_5]] : index
// CHECK: br ^bb3([[VAL_11]] : index)
// CHECK: cf.br ^bb3([[VAL_11]] : index)
// CHECK: ^bb5:
// CHECK: [[VAL_12:%.*]] = arith.addi [[VAL_6]], [[VAL_4]] : index
// CHECK: br ^bb1([[VAL_12]] : index)
// CHECK: cf.br ^bb1([[VAL_12]] : index)
// CHECK: ^bb6:
// CHECK: return
// CHECK: }
@ -262,16 +262,16 @@ func @parallel_loop(%arg0 : index, %arg1 : index, %arg2 : index,
// CHECK-SAME: (%[[LB:.*]]: index, %[[UB:.*]]: index, %[[STEP:.*]]: index)
// CHECK: %[[INIT0:.*]] = arith.constant 0
// CHECK: %[[INIT1:.*]] = arith.constant 1
// CHECK: br ^[[COND:.*]](%[[LB]], %[[INIT0]], %[[INIT1]] : index, f32, f32)
// CHECK: cf.br ^[[COND:.*]](%[[LB]], %[[INIT0]], %[[INIT1]] : index, f32, f32)
//
// CHECK: ^[[COND]](%[[ITER:.*]]: index, %[[ITER_ARG0:.*]]: f32, %[[ITER_ARG1:.*]]: f32):
// CHECK: %[[CMP:.*]] = arith.cmpi slt, %[[ITER]], %[[UB]] : index
// CHECK: cond_br %[[CMP]], ^[[BODY:.*]], ^[[CONTINUE:.*]]
// CHECK: cf.cond_br %[[CMP]], ^[[BODY:.*]], ^[[CONTINUE:.*]]
//
// CHECK: ^[[BODY]]:
// CHECK: %[[SUM:.*]] = arith.addf %[[ITER_ARG0]], %[[ITER_ARG1]] : f32
// CHECK: %[[STEPPED:.*]] = arith.addi %[[ITER]], %[[STEP]] : index
// CHECK: br ^[[COND]](%[[STEPPED]], %[[SUM]], %[[SUM]] : index, f32, f32)
// CHECK: cf.br ^[[COND]](%[[STEPPED]], %[[SUM]], %[[SUM]] : index, f32, f32)
//
// CHECK: ^[[CONTINUE]]:
// CHECK: return %[[ITER_ARG0]], %[[ITER_ARG1]] : f32, f32
@ -288,18 +288,18 @@ func @for_yield(%arg0 : index, %arg1 : index, %arg2 : index) -> (f32, f32) {
// CHECK-LABEL: @nested_for_yield
// CHECK-SAME: (%[[LB:.*]]: index, %[[UB:.*]]: index, %[[STEP:.*]]: index)
// CHECK: %[[INIT:.*]] = arith.constant
// CHECK: br ^[[COND_OUT:.*]](%[[LB]], %[[INIT]] : index, f32)
// CHECK: cf.br ^[[COND_OUT:.*]](%[[LB]], %[[INIT]] : index, f32)
// CHECK: ^[[COND_OUT]](%[[ITER_OUT:.*]]: index, %[[ARG_OUT:.*]]: f32):
// CHECK: cond_br %{{.*}}, ^[[BODY_OUT:.*]], ^[[CONT_OUT:.*]]
// CHECK: cf.cond_br %{{.*}}, ^[[BODY_OUT:.*]], ^[[CONT_OUT:.*]]
// CHECK: ^[[BODY_OUT]]:
// CHECK: br ^[[COND_IN:.*]](%[[LB]], %[[ARG_OUT]] : index, f32)
// CHECK: cf.br ^[[COND_IN:.*]](%[[LB]], %[[ARG_OUT]] : index, f32)
// CHECK: ^[[COND_IN]](%[[ITER_IN:.*]]: index, %[[ARG_IN:.*]]: f32):
// CHECK: cond_br %{{.*}}, ^[[BODY_IN:.*]], ^[[CONT_IN:.*]]
// CHECK: cf.cond_br %{{.*}}, ^[[BODY_IN:.*]], ^[[CONT_IN:.*]]
// CHECK: ^[[BODY_IN]]
// CHECK: %[[RES:.*]] = arith.addf
// CHECK: br ^[[COND_IN]](%{{.*}}, %[[RES]] : index, f32)
// CHECK: cf.br ^[[COND_IN]](%{{.*}}, %[[RES]] : index, f32)
// CHECK: ^[[CONT_IN]]:
// CHECK: br ^[[COND_OUT]](%{{.*}}, %[[ARG_IN]] : index, f32)
// CHECK: cf.br ^[[COND_OUT]](%{{.*}}, %[[ARG_IN]] : index, f32)
// CHECK: ^[[CONT_OUT]]:
// CHECK: return %[[ARG_OUT]] : f32
func @nested_for_yield(%arg0 : index, %arg1 : index, %arg2 : index) -> f32 {
@ -325,13 +325,13 @@ func @simple_parallel_reduce_loop(%arg0: index, %arg1: index,
// passed across as a block argument.
// Branch to the condition block passing in the initial reduction value.
// CHECK: br ^[[COND:.*]](%[[LB]], %[[INIT]]
// CHECK: cf.br ^[[COND:.*]](%[[LB]], %[[INIT]]
// Condition branch takes as arguments the current value of the iteration
// variable and the current partially reduced value.
// CHECK: ^[[COND]](%[[ITER:.*]]: index, %[[ITER_ARG:.*]]: f32
// CHECK: %[[COMP:.*]] = arith.cmpi slt, %[[ITER]], %[[UB]]
// CHECK: cond_br %[[COMP]], ^[[BODY:.*]], ^[[CONTINUE:.*]]
// CHECK: cf.cond_br %[[COMP]], ^[[BODY:.*]], ^[[CONTINUE:.*]]
// Bodies of scf.reduce operations are folded into the main loop body. The
// result of this partial reduction is passed as argument to the condition
@ -340,7 +340,7 @@ func @simple_parallel_reduce_loop(%arg0: index, %arg1: index,
// CHECK: %[[CST:.*]] = arith.constant 4.2
// CHECK: %[[PROD:.*]] = arith.mulf %[[ITER_ARG]], %[[CST]]
// CHECK: %[[INCR:.*]] = arith.addi %[[ITER]], %[[STEP]]
// CHECK: br ^[[COND]](%[[INCR]], %[[PROD]]
// CHECK: cf.br ^[[COND]](%[[INCR]], %[[PROD]]
// The continuation block has access to the (last value of) reduction.
// CHECK: ^[[CONTINUE]]:
@ -363,19 +363,19 @@ func @parallel_reduce_loop(%arg0 : index, %arg1 : index, %arg2 : index,
// Multiple reduction blocks should be folded in the same body, and the
// reduction value must be forwarded through block structures.
// CHECK: %[[INIT2:.*]] = arith.constant 42
// CHECK: br ^[[COND_OUT:.*]](%{{.*}}, %[[INIT1]], %[[INIT2]]
// CHECK: cf.br ^[[COND_OUT:.*]](%{{.*}}, %[[INIT1]], %[[INIT2]]
// CHECK: ^[[COND_OUT]](%{{.*}}: index, %[[ITER_ARG1_OUT:.*]]: f32, %[[ITER_ARG2_OUT:.*]]: i64
// CHECK: cond_br %{{.*}}, ^[[BODY_OUT:.*]], ^[[CONT_OUT:.*]]
// CHECK: cf.cond_br %{{.*}}, ^[[BODY_OUT:.*]], ^[[CONT_OUT:.*]]
// CHECK: ^[[BODY_OUT]]:
// CHECK: br ^[[COND_IN:.*]](%{{.*}}, %[[ITER_ARG1_OUT]], %[[ITER_ARG2_OUT]]
// CHECK: cf.br ^[[COND_IN:.*]](%{{.*}}, %[[ITER_ARG1_OUT]], %[[ITER_ARG2_OUT]]
// CHECK: ^[[COND_IN]](%{{.*}}: index, %[[ITER_ARG1_IN:.*]]: f32, %[[ITER_ARG2_IN:.*]]: i64
// CHECK: cond_br %{{.*}}, ^[[BODY_IN:.*]], ^[[CONT_IN:.*]]
// CHECK: cf.cond_br %{{.*}}, ^[[BODY_IN:.*]], ^[[CONT_IN:.*]]
// CHECK: ^[[BODY_IN]]:
// CHECK: %[[REDUCE1:.*]] = arith.addf %[[ITER_ARG1_IN]], %{{.*}}
// CHECK: %[[REDUCE2:.*]] = arith.ori %[[ITER_ARG2_IN]], %{{.*}}
// CHECK: br ^[[COND_IN]](%{{.*}}, %[[REDUCE1]], %[[REDUCE2]]
// CHECK: cf.br ^[[COND_IN]](%{{.*}}, %[[REDUCE1]], %[[REDUCE2]]
// CHECK: ^[[CONT_IN]]:
// CHECK: br ^[[COND_OUT]](%{{.*}}, %[[ITER_ARG1_IN]], %[[ITER_ARG2_IN]]
// CHECK: cf.br ^[[COND_OUT]](%{{.*}}, %[[ITER_ARG1_IN]], %[[ITER_ARG2_IN]]
// CHECK: ^[[CONT_OUT]]:
// CHECK: return %[[ITER_ARG1_OUT]], %[[ITER_ARG2_OUT]]
%step = arith.constant 1 : index
@ -416,17 +416,17 @@ func @unknown_op_inside_loop(%arg0: index, %arg1: index, %arg2: index) {
// CHECK-LABEL: @minimal_while
func @minimal_while() {
// CHECK: %[[COND:.*]] = "test.make_condition"() : () -> i1
// CHECK: br ^[[BEFORE:.*]]
// CHECK: cf.br ^[[BEFORE:.*]]
%0 = "test.make_condition"() : () -> i1
scf.while : () -> () {
// CHECK: ^[[BEFORE]]:
// CHECK: cond_br %[[COND]], ^[[AFTER:.*]], ^[[CONT:.*]]
// CHECK: cf.cond_br %[[COND]], ^[[AFTER:.*]], ^[[CONT:.*]]
scf.condition(%0)
} do {
// CHECK: ^[[AFTER]]:
// CHECK: "test.some_payload"() : () -> ()
"test.some_payload"() : () -> ()
// CHECK: br ^[[BEFORE]]
// CHECK: cf.br ^[[BEFORE]]
scf.yield
}
// CHECK: ^[[CONT]]:
@ -436,16 +436,16 @@ func @minimal_while() {
// CHECK-LABEL: @do_while
func @do_while(%arg0: f32) {
// CHECK: br ^[[BEFORE:.*]]({{.*}}: f32)
// CHECK: cf.br ^[[BEFORE:.*]]({{.*}}: f32)
scf.while (%arg1 = %arg0) : (f32) -> (f32) {
// CHECK: ^[[BEFORE]](%[[VAL:.*]]: f32):
// CHECK: %[[COND:.*]] = "test.make_condition"() : () -> i1
%0 = "test.make_condition"() : () -> i1
// CHECK: cond_br %[[COND]], ^[[BEFORE]](%[[VAL]] : f32), ^[[CONT:.*]]
// CHECK: cf.cond_br %[[COND]], ^[[BEFORE]](%[[VAL]] : f32), ^[[CONT:.*]]
scf.condition(%0) %arg1 : f32
} do {
^bb0(%arg2: f32):
// CHECK-NOT: br ^[[BEFORE]]
// CHECK-NOT: cf.br ^[[BEFORE]]
scf.yield %arg2 : f32
}
// CHECK: ^[[CONT]]:
@ -460,21 +460,21 @@ func @while_values(%arg0: i32, %arg1: f32) {
%0 = "test.make_condition"() : () -> i1
%c0_i32 = arith.constant 0 : i32
%cst = arith.constant 0.000000e+00 : f32
// CHECK: br ^[[BEFORE:.*]](%[[ARG0]], %[[ARG1]] : i32, f32)
// CHECK: cf.br ^[[BEFORE:.*]](%[[ARG0]], %[[ARG1]] : i32, f32)
%1:2 = scf.while (%arg2 = %arg0, %arg3 = %arg1) : (i32, f32) -> (i64, f64) {
// CHECK: ^bb1(%[[ARG2:.*]]: i32, %[[ARG3:.]]: f32):
// CHECK: %[[VAL1:.*]] = arith.extui %[[ARG0]] : i32 to i64
%2 = arith.extui %arg0 : i32 to i64
// CHECK: %[[VAL2:.*]] = arith.extf %[[ARG3]] : f32 to f64
%3 = arith.extf %arg3 : f32 to f64
// CHECK: cond_br %[[COND]],
// CHECK: cf.cond_br %[[COND]],
// CHECK: ^[[AFTER:.*]](%[[VAL1]], %[[VAL2]] : i64, f64),
// CHECK: ^[[CONT:.*]]
scf.condition(%0) %2, %3 : i64, f64
} do {
// CHECK: ^[[AFTER]](%[[ARG4:.*]]: i64, %[[ARG5:.*]]: f64):
^bb0(%arg2: i64, %arg3: f64):
// CHECK: br ^[[BEFORE]](%{{.*}}, %{{.*}} : i32, f32)
// CHECK: cf.br ^[[BEFORE]](%{{.*}}, %{{.*}} : i32, f32)
scf.yield %c0_i32, %cst : i32, f32
}
// CHECK: ^bb3:
@ -484,17 +484,17 @@ func @while_values(%arg0: i32, %arg1: f32) {
// CHECK-LABEL: @nested_while_ops
func @nested_while_ops(%arg0: f32) -> i64 {
// CHECK: br ^[[OUTER_BEFORE:.*]](%{{.*}} : f32)
// CHECK: cf.br ^[[OUTER_BEFORE:.*]](%{{.*}} : f32)
%0 = scf.while(%outer = %arg0) : (f32) -> i64 {
// CHECK: ^[[OUTER_BEFORE]](%{{.*}}: f32):
// CHECK: %[[OUTER_COND:.*]] = "test.outer_before_pre"() : () -> i1
%cond = "test.outer_before_pre"() : () -> i1
// CHECK: br ^[[INNER_BEFORE_BEFORE:.*]](%{{.*}} : f32)
// CHECK: cf.br ^[[INNER_BEFORE_BEFORE:.*]](%{{.*}} : f32)
%1 = scf.while(%inner = %outer) : (f32) -> i64 {
// CHECK: ^[[INNER_BEFORE_BEFORE]](%{{.*}}: f32):
// CHECK: %[[INNER1:.*]]:2 = "test.inner_before"(%{{.*}}) : (f32) -> (i1, i64)
%2:2 = "test.inner_before"(%inner) : (f32) -> (i1, i64)
// CHECK: cond_br %[[INNER1]]#0,
// CHECK: cf.cond_br %[[INNER1]]#0,
// CHECK: ^[[INNER_BEFORE_AFTER:.*]](%[[INNER1]]#1 : i64),
// CHECK: ^[[OUTER_BEFORE_LAST:.*]]
scf.condition(%2#0) %2#1 : i64
@ -503,13 +503,13 @@ func @nested_while_ops(%arg0: f32) -> i64 {
^bb0(%arg1: i64):
// CHECK: %[[INNER2:.*]] = "test.inner_after"(%{{.*}}) : (i64) -> f32
%3 = "test.inner_after"(%arg1) : (i64) -> f32
// CHECK: br ^[[INNER_BEFORE_BEFORE]](%[[INNER2]] : f32)
// CHECK: cf.br ^[[INNER_BEFORE_BEFORE]](%[[INNER2]] : f32)
scf.yield %3 : f32
}
// CHECK: ^[[OUTER_BEFORE_LAST]]:
// CHECK: "test.outer_before_post"() : () -> ()
"test.outer_before_post"() : () -> ()
// CHECK: cond_br %[[OUTER_COND]],
// CHECK: cf.cond_br %[[OUTER_COND]],
// CHECK: ^[[OUTER_AFTER:.*]](%[[INNER1]]#1 : i64),
// CHECK: ^[[CONTINUATION:.*]]
scf.condition(%cond) %1 : i64
@ -518,12 +518,12 @@ func @nested_while_ops(%arg0: f32) -> i64 {
^bb2(%arg2: i64):
// CHECK: "test.outer_after_pre"(%{{.*}}) : (i64) -> ()
"test.outer_after_pre"(%arg2) : (i64) -> ()
// CHECK: br ^[[INNER_AFTER_BEFORE:.*]](%{{.*}} : i64)
// CHECK: cf.br ^[[INNER_AFTER_BEFORE:.*]](%{{.*}} : i64)
%4 = scf.while(%inner = %arg2) : (i64) -> f32 {
// CHECK: ^[[INNER_AFTER_BEFORE]](%{{.*}}: i64):
// CHECK: %[[INNER3:.*]]:2 = "test.inner2_before"(%{{.*}}) : (i64) -> (i1, f32)
%5:2 = "test.inner2_before"(%inner) : (i64) -> (i1, f32)
// CHECK: cond_br %[[INNER3]]#0,
// CHECK: cf.cond_br %[[INNER3]]#0,
// CHECK: ^[[INNER_AFTER_AFTER:.*]](%[[INNER3]]#1 : f32),
// CHECK: ^[[OUTER_AFTER_LAST:.*]]
scf.condition(%5#0) %5#1 : f32
@ -532,13 +532,13 @@ func @nested_while_ops(%arg0: f32) -> i64 {
^bb3(%arg3: f32):
// CHECK: %{{.*}} = "test.inner2_after"(%{{.*}}) : (f32) -> i64
%6 = "test.inner2_after"(%arg3) : (f32) -> i64
// CHECK: br ^[[INNER_AFTER_BEFORE]](%{{.*}} : i64)
// CHECK: cf.br ^[[INNER_AFTER_BEFORE]](%{{.*}} : i64)
scf.yield %6 : i64
}
// CHECK: ^[[OUTER_AFTER_LAST]]:
// CHECK: "test.outer_after_post"() : () -> ()
"test.outer_after_post"() : () -> ()
// CHECK: br ^[[OUTER_BEFORE]](%[[INNER3]]#1 : f32)
// CHECK: cf.br ^[[OUTER_BEFORE]](%[[INNER3]]#1 : f32)
scf.yield %4 : f32
}
// CHECK: ^[[CONTINUATION]]:
@ -549,27 +549,27 @@ func @nested_while_ops(%arg0: f32) -> i64 {
// CHECK-LABEL: @ifs_in_parallel
// CHECK: (%[[ARG0:.*]]: index, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index, %[[ARG3:.*]]: i1, %[[ARG4:.*]]: i1)
func @ifs_in_parallel(%arg1: index, %arg2: index, %arg3: index, %arg4: i1, %arg5: i1) {
// CHECK: br ^[[LOOP_LATCH:.*]](%[[ARG0]] : index)
// CHECK: cf.br ^[[LOOP_LATCH:.*]](%[[ARG0]] : index)
// CHECK: ^[[LOOP_LATCH]](%[[LOOP_IV:.*]]: index):
// CHECK: %[[LOOP_COND:.*]] = arith.cmpi slt, %[[LOOP_IV]], %[[ARG1]] : index
// CHECK: cond_br %[[LOOP_COND]], ^[[LOOP_BODY:.*]], ^[[LOOP_CONT:.*]]
// CHECK: cf.cond_br %[[LOOP_COND]], ^[[LOOP_BODY:.*]], ^[[LOOP_CONT:.*]]
// CHECK: ^[[LOOP_BODY]]:
// CHECK: cond_br %[[ARG3]], ^[[IF1_THEN:.*]], ^[[IF1_CONT:.*]]
// CHECK: cf.cond_br %[[ARG3]], ^[[IF1_THEN:.*]], ^[[IF1_CONT:.*]]
// CHECK: ^[[IF1_THEN]]:
// CHECK: cond_br %[[ARG4]], ^[[IF2_THEN:.*]], ^[[IF2_ELSE:.*]]
// CHECK: cf.cond_br %[[ARG4]], ^[[IF2_THEN:.*]], ^[[IF2_ELSE:.*]]
// CHECK: ^[[IF2_THEN]]:
// CHECK: %{{.*}} = "test.if2"() : () -> index
// CHECK: br ^[[IF2_MERGE:.*]](%{{.*}} : index)
// CHECK: cf.br ^[[IF2_MERGE:.*]](%{{.*}} : index)
// CHECK: ^[[IF2_ELSE]]:
// CHECK: %{{.*}} = "test.else2"() : () -> index
// CHECK: br ^[[IF2_MERGE]](%{{.*}} : index)
// CHECK: cf.br ^[[IF2_MERGE]](%{{.*}} : index)
// CHECK: ^[[IF2_MERGE]](%{{.*}}: index):
// CHECK: br ^[[IF2_CONT:.*]]
// CHECK: cf.br ^[[IF2_CONT:.*]]
// CHECK: ^[[IF2_CONT]]:
// CHECK: br ^[[IF1_CONT]]
// CHECK: cf.br ^[[IF1_CONT]]
// CHECK: ^[[IF1_CONT]]:
// CHECK: %{{.*}} = arith.addi %[[LOOP_IV]], %[[ARG2]] : index
// CHECK: br ^[[LOOP_LATCH]](%{{.*}} : index)
// CHECK: cf.br ^[[LOOP_LATCH]](%{{.*}} : index)
scf.parallel (%i) = (%arg1) to (%arg2) step (%arg3) {
scf.if %arg4 {
%0 = scf.if %arg5 -> (index) {
@ -593,7 +593,7 @@ func @func_execute_region_elim_multi_yield() {
"test.foo"() : () -> ()
%v = scf.execute_region -> i64 {
%c = "test.cmp"() : () -> i1
cond_br %c, ^bb2, ^bb3
cf.cond_br %c, ^bb2, ^bb3
^bb2:
%x = "test.val1"() : () -> i64
scf.yield %x : i64
@ -607,16 +607,16 @@ func @func_execute_region_elim_multi_yield() {
// CHECK-NOT: execute_region
// CHECK: "test.foo"
// CHECK: br ^[[rentry:.+]]
// CHECK: cf.br ^[[rentry:.+]]
// CHECK: ^[[rentry]]
// CHECK: %[[cmp:.+]] = "test.cmp"
// CHECK: cond_br %[[cmp]], ^[[bb1:.+]], ^[[bb2:.+]]
// CHECK: cf.cond_br %[[cmp]], ^[[bb1:.+]], ^[[bb2:.+]]
// CHECK: ^[[bb1]]:
// CHECK: %[[x:.+]] = "test.val1"
// CHECK: br ^[[bb3:.+]](%[[x]] : i64)
// CHECK: cf.br ^[[bb3:.+]](%[[x]] : i64)
// CHECK: ^[[bb2]]:
// CHECK: %[[y:.+]] = "test.val2"
// CHECK: br ^[[bb3]](%[[y:.+]] : i64)
// CHECK: cf.br ^[[bb3]](%[[y:.+]] : i64)
// CHECK: ^[[bb3]](%[[z:.+]]: i64):
// CHECK: "test.bar"(%[[z]])
// CHECK: return

View File

@ -6,7 +6,7 @@
// CHECK-SAME: %[[RHS:.*]]: tensor<?xindex>) -> !shape.witness {
// CHECK: %[[RET:.*]] = shape.const_witness true
// CHECK: %[[BROADCAST_IS_VALID:.*]] = shape.is_broadcastable %[[LHS]], %[[RHS]]
// CHECK: assert %[[BROADCAST_IS_VALID]], "required broadcastable shapes"
// CHECK: cf.assert %[[BROADCAST_IS_VALID]], "required broadcastable shapes"
// CHECK: return %[[RET]] : !shape.witness
// CHECK: }
func @cstr_broadcastable(%arg0: tensor<?xindex>, %arg1: tensor<?xindex>) -> !shape.witness {
@ -19,7 +19,7 @@ func @cstr_broadcastable(%arg0: tensor<?xindex>, %arg1: tensor<?xindex>) -> !sha
// CHECK-SAME: %[[RHS:.*]]: tensor<?xindex>) -> !shape.witness {
// CHECK: %[[RET:.*]] = shape.const_witness true
// CHECK: %[[EQUAL_IS_VALID:.*]] = shape.shape_eq %[[LHS]], %[[RHS]]
// CHECK: assert %[[EQUAL_IS_VALID]], "required equal shapes"
// CHECK: cf.assert %[[EQUAL_IS_VALID]], "required equal shapes"
// CHECK: return %[[RET]] : !shape.witness
// CHECK: }
func @cstr_eq(%arg0: tensor<?xindex>, %arg1: tensor<?xindex>) -> !shape.witness {
@ -30,7 +30,7 @@ func @cstr_eq(%arg0: tensor<?xindex>, %arg1: tensor<?xindex>) -> !shape.witness
// CHECK-LABEL: func @cstr_require
func @cstr_require(%arg0: i1) -> !shape.witness {
// CHECK: %[[RET:.*]] = shape.const_witness true
// CHECK: assert %arg0, "msg"
// CHECK: cf.assert %arg0, "msg"
// CHECK: return %[[RET]]
%witness = shape.cstr_require %arg0, "msg"
return %witness : !shape.witness

View File

@ -29,7 +29,7 @@ func private @memref_call_conv_nested(%arg0: (memref<?xf32>) -> ())
//CHECK-LABEL: llvm.func @pass_through(%arg0: !llvm.ptr<func<void ()>>) -> !llvm.ptr<func<void ()>> {
func @pass_through(%arg0: () -> ()) -> (() -> ()) {
// CHECK-NEXT: llvm.br ^bb1(%arg0 : !llvm.ptr<func<void ()>>)
br ^bb1(%arg0 : () -> ())
cf.br ^bb1(%arg0 : () -> ())
//CHECK-NEXT: ^bb1(%0: !llvm.ptr<func<void ()>>):
^bb1(%bbarg: () -> ()):

View File

@ -109,17 +109,17 @@ func @loop_carried(%arg0 : index, %arg1 : index, %arg2 : index, %base0 : !base_t
// This test checks that in the BAREPTR case, the branch arguments only forward the descriptor.
// This test was lowered from a simple scf.for that swaps 2 memref iter_args.
// BAREPTR: llvm.br ^bb1(%{{.*}}, %{{.*}}, %{{.*}} : i64, !llvm.struct<(ptr<i32, 201>, ptr<i32, 201>, i64, array<1 x i64>, array<1 x i64>)>, !llvm.struct<(ptr<i32, 201>, ptr<i32, 201>, i64, array<1 x i64>, array<1 x i64>)>)
br ^bb1(%arg0, %base0, %base1 : index, memref<64xi32, 201>, memref<64xi32, 201>)
cf.br ^bb1(%arg0, %base0, %base1 : index, memref<64xi32, 201>, memref<64xi32, 201>)
// BAREPTR-NEXT: ^bb1
// BAREPTR-NEXT: llvm.icmp
// BAREPTR-NEXT: llvm.cond_br %{{.*}}, ^bb2, ^bb3
^bb1(%0: index, %1: memref<64xi32, 201>, %2: memref<64xi32, 201>): // 2 preds: ^bb0, ^bb2
%3 = arith.cmpi slt, %0, %arg1 : index
cond_br %3, ^bb2, ^bb3
cf.cond_br %3, ^bb2, ^bb3
^bb2: // pred: ^bb1
%4 = arith.addi %0, %arg2 : index
br ^bb1(%4, %2, %1 : index, memref<64xi32, 201>, memref<64xi32, 201>)
cf.br ^bb1(%4, %2, %1 : index, memref<64xi32, 201>, memref<64xi32, 201>)
^bb3: // pred: ^bb1
return %1, %2 : memref<64xi32, 201>, memref<64xi32, 201>
}

View File

@ -18,7 +18,7 @@ func @simple_loop() {
^bb0:
// CHECK-NEXT: llvm.br ^bb1
// CHECK32-NEXT: llvm.br ^bb1
br ^bb1
cf.br ^bb1
// CHECK-NEXT: ^bb1: // pred: ^bb0
// CHECK-NEXT: {{.*}} = llvm.mlir.constant(1 : index) : i64
@ -31,7 +31,7 @@ func @simple_loop() {
^bb1: // pred: ^bb0
%c1 = arith.constant 1 : index
%c42 = arith.constant 42 : index
br ^bb2(%c1 : index)
cf.br ^bb2(%c1 : index)
// CHECK: ^bb2({{.*}}: i64): // 2 preds: ^bb1, ^bb3
// CHECK-NEXT: {{.*}} = llvm.icmp "slt" {{.*}}, {{.*}} : i64
@ -41,7 +41,7 @@ func @simple_loop() {
// CHECK32-NEXT: llvm.cond_br {{.*}}, ^bb3, ^bb4
^bb2(%0: index): // 2 preds: ^bb1, ^bb3
%1 = arith.cmpi slt, %0, %c42 : index
cond_br %1, ^bb3, ^bb4
cf.cond_br %1, ^bb3, ^bb4
// CHECK: ^bb3: // pred: ^bb2
// CHECK-NEXT: llvm.call @body({{.*}}) : (i64) -> ()
@ -57,7 +57,7 @@ func @simple_loop() {
call @body(%0) : (index) -> ()
%c1_0 = arith.constant 1 : index
%2 = arith.addi %0, %c1_0 : index
br ^bb2(%2 : index)
cf.br ^bb2(%2 : index)
// CHECK: ^bb4: // pred: ^bb2
// CHECK-NEXT: llvm.return
@ -111,7 +111,7 @@ func private @other(index, i32) -> i32
func @func_args(i32, i32) -> i32 {
^bb0(%arg0: i32, %arg1: i32):
%c0_i32 = arith.constant 0 : i32
br ^bb1
cf.br ^bb1
// CHECK-NEXT: ^bb1: // pred: ^bb0
// CHECK-NEXT: {{.*}} = llvm.mlir.constant(0 : index) : i64
@ -124,7 +124,7 @@ func @func_args(i32, i32) -> i32 {
^bb1: // pred: ^bb0
%c0 = arith.constant 0 : index
%c42 = arith.constant 42 : index
br ^bb2(%c0 : index)
cf.br ^bb2(%c0 : index)
// CHECK-NEXT: ^bb2({{.*}}: i64): // 2 preds: ^bb1, ^bb3
// CHECK-NEXT: {{.*}} = llvm.icmp "slt" {{.*}}, {{.*}} : i64
@ -134,7 +134,7 @@ func @func_args(i32, i32) -> i32 {
// CHECK32-NEXT: llvm.cond_br {{.*}}, ^bb3, ^bb4
^bb2(%0: index): // 2 preds: ^bb1, ^bb3
%1 = arith.cmpi slt, %0, %c42 : index
cond_br %1, ^bb3, ^bb4
cf.cond_br %1, ^bb3, ^bb4
// CHECK-NEXT: ^bb3: // pred: ^bb2
// CHECK-NEXT: {{.*}} = llvm.call @body_args({{.*}}) : (i64) -> i64
@ -159,7 +159,7 @@ func @func_args(i32, i32) -> i32 {
%5 = call @other(%2, %arg1) : (index, i32) -> i32
%c1 = arith.constant 1 : index
%6 = arith.addi %0, %c1 : index
br ^bb2(%6 : index)
cf.br ^bb2(%6 : index)
// CHECK-NEXT: ^bb4: // pred: ^bb2
// CHECK-NEXT: {{.*}} = llvm.mlir.constant(0 : index) : i64
@ -191,7 +191,7 @@ func private @post(index)
// CHECK-NEXT: llvm.br ^bb1
func @imperfectly_nested_loops() {
^bb0:
br ^bb1
cf.br ^bb1
// CHECK-NEXT: ^bb1: // pred: ^bb0
// CHECK-NEXT: {{.*}} = llvm.mlir.constant(0 : index) : i64
@ -200,21 +200,21 @@ func @imperfectly_nested_loops() {
^bb1: // pred: ^bb0
%c0 = arith.constant 0 : index
%c42 = arith.constant 42 : index
br ^bb2(%c0 : index)
cf.br ^bb2(%c0 : index)
// CHECK-NEXT: ^bb2({{.*}}: i64): // 2 preds: ^bb1, ^bb7
// CHECK-NEXT: {{.*}} = llvm.icmp "slt" {{.*}}, {{.*}} : i64
// CHECK-NEXT: llvm.cond_br {{.*}}, ^bb3, ^bb8
^bb2(%0: index): // 2 preds: ^bb1, ^bb7
%1 = arith.cmpi slt, %0, %c42 : index
cond_br %1, ^bb3, ^bb8
cf.cond_br %1, ^bb3, ^bb8
// CHECK-NEXT: ^bb3:
// CHECK-NEXT: llvm.call @pre({{.*}}) : (i64) -> ()
// CHECK-NEXT: llvm.br ^bb4
^bb3: // pred: ^bb2
call @pre(%0) : (index) -> ()
br ^bb4
cf.br ^bb4
// CHECK-NEXT: ^bb4: // pred: ^bb3
// CHECK-NEXT: {{.*}} = llvm.mlir.constant(7 : index) : i64
@ -223,14 +223,14 @@ func @imperfectly_nested_loops() {
^bb4: // pred: ^bb3
%c7 = arith.constant 7 : index
%c56 = arith.constant 56 : index
br ^bb5(%c7 : index)
cf.br ^bb5(%c7 : index)
// CHECK-NEXT: ^bb5({{.*}}: i64): // 2 preds: ^bb4, ^bb6
// CHECK-NEXT: {{.*}} = llvm.icmp "slt" {{.*}}, {{.*}} : i64
// CHECK-NEXT: llvm.cond_br {{.*}}, ^bb6, ^bb7
^bb5(%2: index): // 2 preds: ^bb4, ^bb6
%3 = arith.cmpi slt, %2, %c56 : index
cond_br %3, ^bb6, ^bb7
cf.cond_br %3, ^bb6, ^bb7
// CHECK-NEXT: ^bb6: // pred: ^bb5
// CHECK-NEXT: llvm.call @body2({{.*}}, {{.*}}) : (i64, i64) -> ()
@ -241,7 +241,7 @@ func @imperfectly_nested_loops() {
call @body2(%0, %2) : (index, index) -> ()
%c2 = arith.constant 2 : index
%4 = arith.addi %2, %c2 : index
br ^bb5(%4 : index)
cf.br ^bb5(%4 : index)
// CHECK-NEXT: ^bb7: // pred: ^bb5
// CHECK-NEXT: llvm.call @post({{.*}}) : (i64) -> ()
@ -252,7 +252,7 @@ func @imperfectly_nested_loops() {
call @post(%0) : (index) -> ()
%c1 = arith.constant 1 : index
%5 = arith.addi %0, %c1 : index
br ^bb2(%5 : index)
cf.br ^bb2(%5 : index)
// CHECK-NEXT: ^bb8: // pred: ^bb2
// CHECK-NEXT: llvm.return
@ -316,49 +316,49 @@ func private @body3(index, index)
// CHECK-NEXT: }
func @more_imperfectly_nested_loops() {
^bb0:
br ^bb1
cf.br ^bb1
^bb1: // pred: ^bb0
%c0 = arith.constant 0 : index
%c42 = arith.constant 42 : index
br ^bb2(%c0 : index)
cf.br ^bb2(%c0 : index)
^bb2(%0: index): // 2 preds: ^bb1, ^bb11
%1 = arith.cmpi slt, %0, %c42 : index
cond_br %1, ^bb3, ^bb12
cf.cond_br %1, ^bb3, ^bb12
^bb3: // pred: ^bb2
call @pre(%0) : (index) -> ()
br ^bb4
cf.br ^bb4
^bb4: // pred: ^bb3
%c7 = arith.constant 7 : index
%c56 = arith.constant 56 : index
br ^bb5(%c7 : index)
cf.br ^bb5(%c7 : index)
^bb5(%2: index): // 2 preds: ^bb4, ^bb6
%3 = arith.cmpi slt, %2, %c56 : index
cond_br %3, ^bb6, ^bb7
cf.cond_br %3, ^bb6, ^bb7
^bb6: // pred: ^bb5
call @body2(%0, %2) : (index, index) -> ()
%c2 = arith.constant 2 : index
%4 = arith.addi %2, %c2 : index
br ^bb5(%4 : index)
cf.br ^bb5(%4 : index)
^bb7: // pred: ^bb5
call @mid(%0) : (index) -> ()
br ^bb8
cf.br ^bb8
^bb8: // pred: ^bb7
%c18 = arith.constant 18 : index
%c37 = arith.constant 37 : index
br ^bb9(%c18 : index)
cf.br ^bb9(%c18 : index)
^bb9(%5: index): // 2 preds: ^bb8, ^bb10
%6 = arith.cmpi slt, %5, %c37 : index
cond_br %6, ^bb10, ^bb11
cf.cond_br %6, ^bb10, ^bb11
^bb10: // pred: ^bb9
call @body3(%0, %5) : (index, index) -> ()
%c3 = arith.constant 3 : index
%7 = arith.addi %5, %c3 : index
br ^bb9(%7 : index)
cf.br ^bb9(%7 : index)
^bb11: // pred: ^bb9
call @post(%0) : (index) -> ()
%c1 = arith.constant 1 : index
%8 = arith.addi %0, %c1 : index
br ^bb2(%8 : index)
cf.br ^bb2(%8 : index)
^bb12: // pred: ^bb2
return
}
@ -432,7 +432,7 @@ func @dfs_block_order(%arg0: i32) -> (i32) {
// CHECK-NEXT: %[[CST:.*]] = llvm.mlir.constant(42 : i32) : i32
%0 = arith.constant 42 : i32
// CHECK-NEXT: llvm.br ^bb2
br ^bb2
cf.br ^bb2
// CHECK-NEXT: ^bb1:
// CHECK-NEXT: %[[ADD:.*]] = llvm.add %arg0, %[[CST]] : i32
@ -444,7 +444,7 @@ func @dfs_block_order(%arg0: i32) -> (i32) {
// CHECK-NEXT: ^bb2:
^bb2:
// CHECK-NEXT: llvm.br ^bb1
br ^bb1
cf.br ^bb1
}
// -----
@ -469,7 +469,7 @@ func @floorf(%arg0 : f32) {
// -----
// Lowers `assert` to a function call to `abort` if the assertion is violated.
// Lowers `cf.assert` to a function call to `abort` if the assertion is violated.
// CHECK: llvm.func @abort()
// CHECK-LABEL: @assert_test_function
// CHECK-SAME: (%[[ARG:.*]]: i1)
@ -480,7 +480,7 @@ func @assert_test_function(%arg : i1) {
// CHECK: ^[[FAILURE_BLOCK]]:
// CHECK: llvm.call @abort() : () -> ()
// CHECK: llvm.unreachable
assert %arg, "Computer says no"
cf.assert %arg, "Computer says no"
return
}
@ -514,8 +514,8 @@ func @fmaf(%arg0: f32, %arg1: vector<4xf32>) {
// CHECK-LABEL: func @switchi8(
func @switchi8(%arg0 : i8) -> i32 {
switch %arg0 : i8, [
default: ^bb1,
cf.switch %arg0 : i8, [
default: ^bb1,
42: ^bb1,
43: ^bb3
]

View File

@ -900,45 +900,3 @@ func @tensor_extract_constant(%a : index, %b: index, %c: index) -> i32 {
// CHECK: spv.ReturnValue %[[VAL]]
return %extract : i32
}
// -----
//===----------------------------------------------------------------------===//
// std.br, std.cond_br
//===----------------------------------------------------------------------===//
module attributes {
spv.target_env = #spv.target_env<#spv.vce<v1.0, [], []>, {}>
} {
// CHECK-LABEL: func @simple_loop
func @simple_loop(index, index, index) {
^bb0(%begin : index, %end : index, %step : index):
// CHECK-NEXT: spv.Branch ^bb1
br ^bb1
// CHECK-NEXT: ^bb1: // pred: ^bb0
// CHECK-NEXT: spv.Branch ^bb2({{.*}} : i32)
^bb1: // pred: ^bb0
br ^bb2(%begin : index)
// CHECK: ^bb2({{.*}}: i32): // 2 preds: ^bb1, ^bb3
// CHECK-NEXT: {{.*}} = spv.SLessThan {{.*}}, {{.*}} : i32
// CHECK-NEXT: spv.BranchConditional {{.*}}, ^bb3, ^bb4
^bb2(%0: index): // 2 preds: ^bb1, ^bb3
%1 = arith.cmpi slt, %0, %end : index
cond_br %1, ^bb3, ^bb4
// CHECK: ^bb3: // pred: ^bb2
// CHECK-NEXT: {{.*}} = spv.IAdd {{.*}}, {{.*}} : i32
// CHECK-NEXT: spv.Branch ^bb2({{.*}} : i32)
^bb3: // pred: ^bb2
%2 = arith.addi %0, %step : index
br ^bb2(%2 : index)
// CHECK: ^bb4: // pred: ^bb2
^bb4: // pred: ^bb2
return
}
}

View File

@ -56,9 +56,9 @@ func @affine_load_invalid_dim(%M : memref<10xi32>) {
^bb0(%arg: index):
affine.load %M[%arg] : memref<10xi32>
// expected-error@-1 {{index must be a dimension or symbol identifier}}
br ^bb1
cf.br ^bb1
^bb1:
br ^bb1
cf.br ^bb1
}) : () -> ()
return
}

View File

@ -54,13 +54,13 @@ func @token_value_to_func() {
// CHECK-LABEL: @token_arg_cond_br_await_with_fallthough
// CHECK: %[[TOKEN:.*]]: !async.token
func @token_arg_cond_br_await_with_fallthough(%arg0: !async.token, %arg1: i1) {
// CHECK: cond_br
// CHECK: cf.cond_br
// CHECK-SAME: ^[[BB1:.*]], ^[[BB2:.*]]
cond_br %arg1, ^bb1, ^bb2
cf.cond_br %arg1, ^bb1, ^bb2
^bb1:
// CHECK: ^[[BB1]]:
// CHECK: br ^[[BB2]]
br ^bb2
// CHECK: cf.br ^[[BB2]]
cf.br ^bb2
^bb2:
// CHECK: ^[[BB2]]:
// CHECK: async.runtime.await %[[TOKEN]]
@ -88,10 +88,10 @@ func @token_coro_return() -> !async.token {
async.runtime.resume %hdl
async.coro.suspend %saved, ^suspend, ^resume, ^cleanup
^resume:
br ^cleanup
cf.br ^cleanup
^cleanup:
async.coro.free %id, %hdl
br ^suspend
cf.br ^suspend
^suspend:
async.coro.end %hdl
return %token : !async.token
@ -109,10 +109,10 @@ func @token_coro_await_and_resume(%arg0: !async.token) -> !async.token {
// CHECK-NEXT: async.runtime.drop_ref %[[TOKEN]] {count = 1 : i64}
async.coro.suspend %saved, ^suspend, ^resume, ^cleanup
^resume:
br ^cleanup
cf.br ^cleanup
^cleanup:
async.coro.free %id, %hdl
br ^suspend
cf.br ^suspend
^suspend:
async.coro.end %hdl
return %token : !async.token
@ -137,10 +137,10 @@ func @value_coro_await_and_resume(%arg0: !async.value<f32>) -> !async.token {
%0 = async.runtime.load %arg0 : !async.value<f32>
// CHECK: arith.addf %[[LOADED]], %[[LOADED]]
%1 = arith.addf %0, %0 : f32
br ^cleanup
cf.br ^cleanup
^cleanup:
async.coro.free %id, %hdl
br ^suspend
cf.br ^suspend
^suspend:
async.coro.end %hdl
return %token : !async.token
@ -167,12 +167,12 @@ func private @outlined_async_execute(%arg0: !async.token) -> !async.token {
// CHECK: ^[[RESUME_1:.*]]:
// CHECK: async.runtime.set_available
async.runtime.set_available %0 : !async.token
br ^cleanup
cf.br ^cleanup
^cleanup:
// CHECK: ^[[CLEANUP:.*]]:
// CHECK: async.coro.free
async.coro.free %1, %2
br ^suspend
cf.br ^suspend
^suspend:
// CHECK: ^[[SUSPEND:.*]]:
// CHECK: async.coro.end
@ -198,7 +198,7 @@ func @token_await_inside_nested_region(%arg0: i1) {
// CHECK-LABEL: @token_defined_in_the_loop
func @token_defined_in_the_loop() {
br ^bb1
cf.br ^bb1
^bb1:
// CHECK: ^[[BB1:.*]]:
// CHECK: %[[TOKEN:.*]] = call @token()
@ -207,7 +207,7 @@ func @token_defined_in_the_loop() {
// CHECK: async.runtime.drop_ref %[[TOKEN]] {count = 1 : i64}
async.runtime.await %token : !async.token
%0 = call @cond(): () -> (i1)
cond_br %0, ^bb1, ^bb2
cf.cond_br %0, ^bb1, ^bb2
^bb2:
// CHECK: ^[[BB2:.*]]:
// CHECK: return
@ -218,18 +218,18 @@ func @token_defined_in_the_loop() {
func @divergent_liveness_one_token(%arg0 : i1) {
// CHECK: %[[TOKEN:.*]] = call @token()
%token = call @token() : () -> !async.token
// CHECK: cond_br %arg0, ^[[LIVE_IN:.*]], ^[[REF_COUNTING:.*]]
cond_br %arg0, ^bb1, ^bb2
// CHECK: cf.cond_br %arg0, ^[[LIVE_IN:.*]], ^[[REF_COUNTING:.*]]
cf.cond_br %arg0, ^bb1, ^bb2
^bb1:
// CHECK: ^[[LIVE_IN]]:
// CHECK: async.runtime.await %[[TOKEN]]
// CHECK: async.runtime.drop_ref %[[TOKEN]] {count = 1 : i64}
// CHECK: br ^[[RETURN:.*]]
// CHECK: cf.br ^[[RETURN:.*]]
async.runtime.await %token : !async.token
br ^bb2
cf.br ^bb2
// CHECK: ^[[REF_COUNTING:.*]]:
// CHECK: async.runtime.drop_ref %[[TOKEN]] {count = 1 : i64}
// CHECK: br ^[[RETURN:.*]]
// CHECK: cf.br ^[[RETURN:.*]]
^bb2:
// CHECK: ^[[RETURN]]:
// CHECK: return
@ -240,20 +240,20 @@ func @divergent_liveness_one_token(%arg0 : i1) {
func @divergent_liveness_unique_predecessor(%arg0 : i1) {
// CHECK: %[[TOKEN:.*]] = call @token()
%token = call @token() : () -> !async.token
// CHECK: cond_br %arg0, ^[[LIVE_IN:.*]], ^[[NO_LIVE_IN:.*]]
cond_br %arg0, ^bb2, ^bb1
// CHECK: cf.cond_br %arg0, ^[[LIVE_IN:.*]], ^[[NO_LIVE_IN:.*]]
cf.cond_br %arg0, ^bb2, ^bb1
^bb1:
// CHECK: ^[[NO_LIVE_IN]]:
// CHECK: async.runtime.drop_ref %[[TOKEN]] {count = 1 : i64}
// CHECK: br ^[[RETURN:.*]]
br ^bb3
// CHECK: cf.br ^[[RETURN:.*]]
cf.br ^bb3
^bb2:
// CHECK: ^[[LIVE_IN]]:
// CHECK: async.runtime.await %[[TOKEN]]
// CHECK: async.runtime.drop_ref %[[TOKEN]] {count = 1 : i64}
// CHECK: br ^[[RETURN]]
// CHECK: cf.br ^[[RETURN]]
async.runtime.await %token : !async.token
br ^bb3
cf.br ^bb3
^bb3:
// CHECK: ^[[RETURN]]:
// CHECK: return
@ -266,24 +266,24 @@ func @divergent_liveness_two_tokens(%arg0 : i1) {
// CHECK: %[[TOKEN1:.*]] = call @token()
%token0 = call @token() : () -> !async.token
%token1 = call @token() : () -> !async.token
// CHECK: cond_br %arg0, ^[[AWAIT0:.*]], ^[[AWAIT1:.*]]
cond_br %arg0, ^await0, ^await1
// CHECK: cf.cond_br %arg0, ^[[AWAIT0:.*]], ^[[AWAIT1:.*]]
cf.cond_br %arg0, ^await0, ^await1
^await0:
// CHECK: ^[[AWAIT0]]:
// CHECK: async.runtime.drop_ref %[[TOKEN1]] {count = 1 : i64}
// CHECK: async.runtime.await %[[TOKEN0]]
// CHECK: async.runtime.drop_ref %[[TOKEN0]] {count = 1 : i64}
// CHECK: br ^[[RETURN:.*]]
// CHECK: cf.br ^[[RETURN:.*]]
async.runtime.await %token0 : !async.token
br ^ret
cf.br ^ret
^await1:
// CHECK: ^[[AWAIT1]]:
// CHECK: async.runtime.drop_ref %[[TOKEN0]] {count = 1 : i64}
// CHECK: async.runtime.await %[[TOKEN1]]
// CHECK: async.runtime.drop_ref %[[TOKEN1]] {count = 1 : i64}
// CHECK: br ^[[RETURN]]
// CHECK: cf.br ^[[RETURN]]
async.runtime.await %token1 : !async.token
br ^ret
cf.br ^ret
^ret:
// CHECK: ^[[RETURN]]:
// CHECK: return

View File

@ -10,7 +10,7 @@ func @simple_callee(%arg0: f32) -> (f32 {builtin.foo = "bar"}) {
// CHECK: %[[RETURNED_STORAGE:.*]] = async.runtime.create : !async.value<f32>
// CHECK: %[[ID:.*]] = async.coro.id
// CHECK: %[[HDL:.*]] = async.coro.begin %[[ID]]
// CHECK: br ^[[ORIGINAL_ENTRY:.*]]
// CHECK: cf.br ^[[ORIGINAL_ENTRY:.*]]
// CHECK ^[[ORIGINAL_ENTRY]]:
// CHECK: %[[VAL:.*]] = arith.addf %[[ARG]], %[[ARG]] : f32
%0 = arith.addf %arg0, %arg0 : f32
@ -29,7 +29,7 @@ func @simple_callee(%arg0: f32) -> (f32 {builtin.foo = "bar"}) {
// CHECK: ^[[RESUME]]:
// CHECK: %[[IS_ERROR:.*]] = async.runtime.is_error %[[VAL_STORAGE]] : !async.value<f32>
// CHECK: cond_br %[[IS_ERROR]], ^[[BRANCH_ERROR:.*]], ^[[BRANCH_OK:.*]]
// CHECK: cf.cond_br %[[IS_ERROR]], ^[[BRANCH_ERROR:.*]], ^[[BRANCH_OK:.*]]
// CHECK: ^[[BRANCH_OK]]:
// CHECK: %[[LOADED:.*]] = async.runtime.load %[[VAL_STORAGE]] : <f32>
@ -37,19 +37,19 @@ func @simple_callee(%arg0: f32) -> (f32 {builtin.foo = "bar"}) {
// CHECK: async.runtime.store %[[RETURNED]], %[[RETURNED_STORAGE]] : <f32>
// CHECK: async.runtime.set_available %[[RETURNED_STORAGE]]
// CHECK: async.runtime.set_available %[[TOKEN]]
// CHECK: br ^[[CLEANUP]]
// CHECK: cf.br ^[[CLEANUP]]
%3 = arith.mulf %arg0, %2 : f32
return %3: f32
// CHECK: ^[[BRANCH_ERROR]]:
// CHECK: async.runtime.set_error %[[TOKEN]]
// CHECK: async.runtime.set_error %[[RETURNED_STORAGE]]
// CHECK: br ^[[CLEANUP]]
// CHECK: cf.br ^[[CLEANUP]]
// CHECK: ^[[CLEANUP]]:
// CHECK: async.coro.free %[[ID]], %[[HDL]]
// CHECK: br ^[[SUSPEND]]
// CHECK: cf.br ^[[SUSPEND]]
// CHECK: ^[[SUSPEND]]:
// CHECK: async.coro.end %[[HDL]]
@ -63,7 +63,7 @@ func @simple_caller() -> f32 {
// CHECK: %[[RETURNED_STORAGE:.*]] = async.runtime.create : !async.value<f32>
// CHECK: %[[ID:.*]] = async.coro.id
// CHECK: %[[HDL:.*]] = async.coro.begin %[[ID]]
// CHECK: br ^[[ORIGINAL_ENTRY:.*]]
// CHECK: cf.br ^[[ORIGINAL_ENTRY:.*]]
// CHECK ^[[ORIGINAL_ENTRY]]:
// CHECK: %[[CONSTANT:.*]] = arith.constant
@ -77,28 +77,28 @@ func @simple_caller() -> f32 {
// CHECK: ^[[RESUME]]:
// CHECK: %[[IS_TOKEN_ERROR:.*]] = async.runtime.is_error %[[RETURNED_TO_CALLER]]#0 : !async.token
// CHECK: cond_br %[[IS_TOKEN_ERROR]], ^[[BRANCH_ERROR:.*]], ^[[BRANCH_TOKEN_OK:.*]]
// CHECK: cf.cond_br %[[IS_TOKEN_ERROR]], ^[[BRANCH_ERROR:.*]], ^[[BRANCH_TOKEN_OK:.*]]
// CHECK: ^[[BRANCH_TOKEN_OK]]:
// CHECK: %[[IS_VALUE_ERROR:.*]] = async.runtime.is_error %[[RETURNED_TO_CALLER]]#1 : !async.value<f32>
// CHECK: cond_br %[[IS_VALUE_ERROR]], ^[[BRANCH_ERROR:.*]], ^[[BRANCH_VALUE_OK:.*]]
// CHECK: cf.cond_br %[[IS_VALUE_ERROR]], ^[[BRANCH_ERROR:.*]], ^[[BRANCH_VALUE_OK:.*]]
// CHECK: ^[[BRANCH_VALUE_OK]]:
// CHECK: %[[LOADED:.*]] = async.runtime.load %[[RETURNED_TO_CALLER]]#1 : <f32>
// CHECK: async.runtime.store %[[LOADED]], %[[RETURNED_STORAGE]] : <f32>
// CHECK: async.runtime.set_available %[[RETURNED_STORAGE]]
// CHECK: async.runtime.set_available %[[TOKEN]]
// CHECK: br ^[[CLEANUP]]
// CHECK: cf.br ^[[CLEANUP]]
return %r: f32
// CHECK: ^[[BRANCH_ERROR]]:
// CHECK: async.runtime.set_error %[[TOKEN]]
// CHECK: async.runtime.set_error %[[RETURNED_STORAGE]]
// CHECK: br ^[[CLEANUP]]
// CHECK: cf.br ^[[CLEANUP]]
// CHECK: ^[[CLEANUP]]:
// CHECK: async.coro.free %[[ID]], %[[HDL]]
// CHECK: br ^[[SUSPEND]]
// CHECK: cf.br ^[[SUSPEND]]
// CHECK: ^[[SUSPEND]]:
// CHECK: async.coro.end %[[HDL]]
@ -112,7 +112,7 @@ func @double_caller() -> f32 {
// CHECK: %[[RETURNED_STORAGE:.*]] = async.runtime.create : !async.value<f32>
// CHECK: %[[ID:.*]] = async.coro.id
// CHECK: %[[HDL:.*]] = async.coro.begin %[[ID]]
// CHECK: br ^[[ORIGINAL_ENTRY:.*]]
// CHECK: cf.br ^[[ORIGINAL_ENTRY:.*]]
// CHECK ^[[ORIGINAL_ENTRY]]:
// CHECK: %[[CONSTANT:.*]] = arith.constant
@ -126,11 +126,11 @@ func @double_caller() -> f32 {
// CHECK: ^[[RESUME_1]]:
// CHECK: %[[IS_TOKEN_ERROR_1:.*]] = async.runtime.is_error %[[RETURNED_TO_CALLER_1]]#0 : !async.token
// CHECK: cond_br %[[IS_TOKEN_ERROR_1]], ^[[BRANCH_ERROR:.*]], ^[[BRANCH_TOKEN_OK_1:.*]]
// CHECK: cf.cond_br %[[IS_TOKEN_ERROR_1]], ^[[BRANCH_ERROR:.*]], ^[[BRANCH_TOKEN_OK_1:.*]]
// CHECK: ^[[BRANCH_TOKEN_OK_1]]:
// CHECK: %[[IS_VALUE_ERROR_1:.*]] = async.runtime.is_error %[[RETURNED_TO_CALLER_1]]#1 : !async.value<f32>
// CHECK: cond_br %[[IS_VALUE_ERROR_1]], ^[[BRANCH_ERROR:.*]], ^[[BRANCH_VALUE_OK_1:.*]]
// CHECK: cf.cond_br %[[IS_VALUE_ERROR_1]], ^[[BRANCH_ERROR:.*]], ^[[BRANCH_VALUE_OK_1:.*]]
// CHECK: ^[[BRANCH_VALUE_OK_1]]:
// CHECK: %[[LOADED_1:.*]] = async.runtime.load %[[RETURNED_TO_CALLER_1]]#1 : <f32>
@ -143,27 +143,27 @@ func @double_caller() -> f32 {
// CHECK: ^[[RESUME_2]]:
// CHECK: %[[IS_TOKEN_ERROR_2:.*]] = async.runtime.is_error %[[RETURNED_TO_CALLER_2]]#0 : !async.token
// CHECK: cond_br %[[IS_TOKEN_ERROR_2]], ^[[BRANCH_ERROR:.*]], ^[[BRANCH_TOKEN_OK_2:.*]]
// CHECK: cf.cond_br %[[IS_TOKEN_ERROR_2]], ^[[BRANCH_ERROR:.*]], ^[[BRANCH_TOKEN_OK_2:.*]]
// CHECK: ^[[BRANCH_TOKEN_OK_2]]:
// CHECK: %[[IS_VALUE_ERROR_2:.*]] = async.runtime.is_error %[[RETURNED_TO_CALLER_2]]#1 : !async.value<f32>
// CHECK: cond_br %[[IS_VALUE_ERROR_2]], ^[[BRANCH_ERROR:.*]], ^[[BRANCH_VALUE_OK_2:.*]]
// CHECK: cf.cond_br %[[IS_VALUE_ERROR_2]], ^[[BRANCH_ERROR:.*]], ^[[BRANCH_VALUE_OK_2:.*]]
// CHECK: ^[[BRANCH_VALUE_OK_2]]:
// CHECK: %[[LOADED_2:.*]] = async.runtime.load %[[RETURNED_TO_CALLER_2]]#1 : <f32>
// CHECK: async.runtime.store %[[LOADED_2]], %[[RETURNED_STORAGE]] : <f32>
// CHECK: async.runtime.set_available %[[RETURNED_STORAGE]]
// CHECK: async.runtime.set_available %[[TOKEN]]
// CHECK: br ^[[CLEANUP]]
// CHECK: cf.br ^[[CLEANUP]]
return %s: f32
// CHECK: ^[[BRANCH_ERROR]]:
// CHECK: async.runtime.set_error %[[TOKEN]]
// CHECK: async.runtime.set_error %[[RETURNED_STORAGE]]
// CHECK: br ^[[CLEANUP]]
// CHECK: cf.br ^[[CLEANUP]]
// CHECK: ^[[CLEANUP]]:
// CHECK: async.coro.free %[[ID]], %[[HDL]]
// CHECK: br ^[[SUSPEND]]
// CHECK: cf.br ^[[SUSPEND]]
// CHECK: ^[[SUSPEND]]:
// CHECK: async.coro.end %[[HDL]]
@ -184,7 +184,7 @@ func @recursive(%arg: !async.token) {
async.await %arg : !async.token
// CHECK: ^[[RESUME_1]]:
// CHECK: %[[IS_ERROR:.*]] = async.runtime.is_error %[[ARG]] : !async.token
// CHECK: cond_br %[[IS_ERROR]], ^[[BRANCH_ERROR:.*]], ^[[BRANCH_OK:.*]]
// CHECK: cf.cond_br %[[IS_ERROR]], ^[[BRANCH_ERROR:.*]], ^[[BRANCH_OK:.*]]
// CHECK: ^[[BRANCH_OK]]:
// CHECK: %[[GIVEN:.*]] = async.runtime.create : !async.token
@ -200,16 +200,16 @@ call @recursive(%r): (!async.token) -> ()
// CHECK: ^[[RESUME_2]]:
// CHECK: async.runtime.set_available %[[TOKEN]]
// CHECK: br ^[[CLEANUP]]
// CHECK: cf.br ^[[CLEANUP]]
// CHECK: ^[[BRANCH_ERROR]]:
// CHECK: async.runtime.set_error %[[TOKEN]]
// CHECK: br ^[[CLEANUP]]
// CHECK: cf.br ^[[CLEANUP]]
return
// CHECK: ^[[CLEANUP]]:
// CHECK: async.coro.free %[[ID]], %[[HDL]]
// CHECK: br ^[[SUSPEND]]
// CHECK: cf.br ^[[SUSPEND]]
// CHECK: ^[[SUSPEND]]:
// CHECK: async.coro.end %[[HDL]]
@ -230,7 +230,7 @@ func @corecursive1(%arg: !async.token) {
async.await %arg : !async.token
// CHECK: ^[[RESUME_1]]:
// CHECK: %[[IS_ERROR:.*]] = async.runtime.is_error %[[ARG]] : !async.token
// CHECK: cond_br %[[IS_ERROR]], ^[[BRANCH_ERROR:.*]], ^[[BRANCH_OK:.*]]
// CHECK: cf.cond_br %[[IS_ERROR]], ^[[BRANCH_ERROR:.*]], ^[[BRANCH_OK:.*]]
// CHECK: ^[[BRANCH_OK]]:
// CHECK: %[[GIVEN:.*]] = async.runtime.create : !async.token
@ -246,16 +246,16 @@ call @corecursive2(%r): (!async.token) -> ()
// CHECK: ^[[RESUME_2]]:
// CHECK: async.runtime.set_available %[[TOKEN]]
// CHECK: br ^[[CLEANUP]]
// CHECK: cf.br ^[[CLEANUP]]
// CHECK: ^[[BRANCH_ERROR]]:
// CHECK: async.runtime.set_error %[[TOKEN]]
// CHECK: br ^[[CLEANUP]]
// CHECK: cf.br ^[[CLEANUP]]
return
// CHECK: ^[[CLEANUP]]:
// CHECK: async.coro.free %[[ID]], %[[HDL]]
// CHECK: br ^[[SUSPEND]]
// CHECK: cf.br ^[[SUSPEND]]
// CHECK: ^[[SUSPEND]]:
// CHECK: async.coro.end %[[HDL]]
@ -276,7 +276,7 @@ func @corecursive2(%arg: !async.token) {
async.await %arg : !async.token
// CHECK: ^[[RESUME_1]]:
// CHECK: %[[IS_ERROR:.*]] = async.runtime.is_error %[[ARG]] : !async.token
// CHECK: cond_br %[[IS_ERROR]], ^[[BRANCH_ERROR:.*]], ^[[BRANCH_OK:.*]]
// CHECK: cf.cond_br %[[IS_ERROR]], ^[[BRANCH_ERROR:.*]], ^[[BRANCH_OK:.*]]
// CHECK: ^[[BRANCH_OK]]:
// CHECK: %[[GIVEN:.*]] = async.runtime.create : !async.token
@ -292,16 +292,16 @@ call @corecursive1(%r): (!async.token) -> ()
// CHECK: ^[[RESUME_2]]:
// CHECK: async.runtime.set_available %[[TOKEN]]
// CHECK: br ^[[CLEANUP]]
// CHECK: cf.br ^[[CLEANUP]]
// CHECK: ^[[BRANCH_ERROR]]:
// CHECK: async.runtime.set_error %[[TOKEN]]
// CHECK: br ^[[CLEANUP]]
// CHECK: cf.br ^[[CLEANUP]]
return
// CHECK: ^[[CLEANUP]]:
// CHECK: async.coro.free %[[ID]], %[[HDL]]
// CHECK: br ^[[SUSPEND]]
// CHECK: cf.br ^[[SUSPEND]]
// CHECK: ^[[SUSPEND]]:
// CHECK: async.coro.end %[[HDL]]

View File

@ -63,7 +63,7 @@ func @nested_async_execute(%arg0: f32, %arg1: f32, %arg2: memref<1xf32>) {
// CHECK: %[[IS_ERROR:.*]] = async.runtime.is_error %[[TOKEN]]
// CHECK: %[[TRUE:.*]] = arith.constant true
// CHECK: %[[NOT_ERROR:.*]] = arith.xori %[[IS_ERROR]], %[[TRUE]] : i1
// CHECK: assert %[[NOT_ERROR]]
// CHECK: cf.assert %[[NOT_ERROR]]
// CHECK-NEXT: return
async.await %token0 : !async.token
return
@ -109,7 +109,7 @@ func @nested_async_execute(%arg0: f32, %arg1: f32, %arg2: memref<1xf32>) {
// Check the error of the awaited token after resumption.
// CHECK: ^[[RESUME_1]]:
// CHECK: %[[ERR:.*]] = async.runtime.is_error %[[INNER_TOKEN]]
// CHECK: cond_br %[[ERR]], ^[[SET_ERROR:.*]], ^[[CONTINUATION:.*]]
// CHECK: cf.cond_br %[[ERR]], ^[[SET_ERROR:.*]], ^[[CONTINUATION:.*]]
// Set token available if the token is not in the error state.
// CHECK: ^[[CONTINUATION:.*]]:
@ -169,7 +169,7 @@ func @async_execute_token_dependency(%arg0: f32, %arg1: memref<1xf32>) {
// Check the error of the awaited token after resumption.
// CHECK: ^[[RESUME_1]]:
// CHECK: %[[ERR:.*]] = async.runtime.is_error %[[ARG0]]
// CHECK: cond_br %[[ERR]], ^[[SET_ERROR:.*]], ^[[CONTINUATION:.*]]
// CHECK: cf.cond_br %[[ERR]], ^[[SET_ERROR:.*]], ^[[CONTINUATION:.*]]
// Emplace result token after second resumption and error checking.
// CHECK: ^[[CONTINUATION:.*]]:
@ -225,7 +225,7 @@ func @async_group_await_all(%arg0: f32, %arg1: memref<1xf32>) {
// Check the error of the awaited token after resumption.
// CHECK: ^[[RESUME_1]]:
// CHECK: %[[ERR:.*]] = async.runtime.is_error %[[ARG]]
// CHECK: cond_br %[[ERR]], ^[[SET_ERROR:.*]], ^[[CONTINUATION:.*]]
// CHECK: cf.cond_br %[[ERR]], ^[[SET_ERROR:.*]], ^[[CONTINUATION:.*]]
// Emplace result token after error checking.
// CHECK: ^[[CONTINUATION:.*]]:
@ -319,7 +319,7 @@ func @async_value_operands() {
// Check the error of the awaited token after resumption.
// CHECK: ^[[RESUME_1]]:
// CHECK: %[[ERR:.*]] = async.runtime.is_error %[[ARG]]
// CHECK: cond_br %[[ERR]], ^[[SET_ERROR:.*]], ^[[CONTINUATION:.*]]
// CHECK: cf.cond_br %[[ERR]], ^[[SET_ERROR:.*]], ^[[CONTINUATION:.*]]
// // Load from the async.value argument after error checking.
// CHECK: ^[[CONTINUATION:.*]]:
@ -335,7 +335,7 @@ func @async_value_operands() {
// CHECK-LABEL: @execute_assertion
func @execute_assertion(%arg0: i1) {
%token = async.execute {
assert %arg0, "error"
cf.assert %arg0, "error"
async.yield
}
async.await %token : !async.token
@ -358,17 +358,17 @@ func @execute_assertion(%arg0: i1) {
// Resume coroutine after suspension.
// CHECK: ^[[RESUME]]:
// CHECK: cond_br %[[ARG0]], ^[[SET_AVAILABLE:.*]], ^[[SET_ERROR:.*]]
// CHECK: cf.cond_br %[[ARG0]], ^[[SET_AVAILABLE:.*]], ^[[SET_ERROR:.*]]
// Set coroutine completion token to available state.
// CHECK: ^[[SET_AVAILABLE]]:
// CHECK: async.runtime.set_available %[[TOKEN]]
// CHECK: br ^[[CLEANUP]]
// CHECK: cf.br ^[[CLEANUP]]
// Set coroutine completion token to error state.
// CHECK: ^[[SET_ERROR]]:
// CHECK: async.runtime.set_error %[[TOKEN]]
// CHECK: br ^[[CLEANUP]]
// CHECK: cf.br ^[[CLEANUP]]
// Delete coroutine.
// CHECK: ^[[CLEANUP]]:
@ -409,7 +409,7 @@ func @lower_scf_to_cfg(%arg0: f32, %arg1: memref<1xf32>, %arg2: i1) {
// Check that structured control flow lowered to CFG.
// CHECK-NOT: scf.if
// CHECK: cond_br %[[FLAG]]
// CHECK: cf.cond_br %[[FLAG]]
// -----
// Constants captured by the async.execute region should be cloned into the

View File

@ -17,26 +17,26 @@
// CHECK-LABEL: func @condBranch
func @condBranch(%arg0: i1, %arg1: memref<2xf32>, %arg2: memref<2xf32>) {
cond_br %arg0, ^bb1, ^bb2
cf.cond_br %arg0, ^bb1, ^bb2
^bb1:
br ^bb3(%arg1 : memref<2xf32>)
cf.br ^bb3(%arg1 : memref<2xf32>)
^bb2:
%0 = memref.alloc() : memref<2xf32>
test.buffer_based in(%arg1: memref<2xf32>) out(%0: memref<2xf32>)
br ^bb3(%0 : memref<2xf32>)
cf.br ^bb3(%0 : memref<2xf32>)
^bb3(%1: memref<2xf32>):
test.copy(%1, %arg2) : (memref<2xf32>, memref<2xf32>)
return
}
// CHECK-NEXT: cond_br
// CHECK-NEXT: cf.cond_br
// CHECK: %[[ALLOC0:.*]] = bufferization.clone
// CHECK-NEXT: br ^bb3(%[[ALLOC0]]
// CHECK-NEXT: cf.br ^bb3(%[[ALLOC0]]
// CHECK: %[[ALLOC1:.*]] = memref.alloc
// CHECK-NEXT: test.buffer_based
// CHECK-NEXT: %[[ALLOC2:.*]] = bufferization.clone %[[ALLOC1]]
// CHECK-NEXT: memref.dealloc %[[ALLOC1]]
// CHECK-NEXT: br ^bb3(%[[ALLOC2]]
// CHECK-NEXT: cf.br ^bb3(%[[ALLOC2]]
// CHECK: test.copy
// CHECK-NEXT: memref.dealloc
// CHECK-NEXT: return
@ -62,27 +62,27 @@ func @condBranchDynamicType(
%arg1: memref<?xf32>,
%arg2: memref<?xf32>,
%arg3: index) {
cond_br %arg0, ^bb1, ^bb2(%arg3: index)
cf.cond_br %arg0, ^bb1, ^bb2(%arg3: index)
^bb1:
br ^bb3(%arg1 : memref<?xf32>)
cf.br ^bb3(%arg1 : memref<?xf32>)
^bb2(%0: index):
%1 = memref.alloc(%0) : memref<?xf32>
test.buffer_based in(%arg1: memref<?xf32>) out(%1: memref<?xf32>)
br ^bb3(%1 : memref<?xf32>)
cf.br ^bb3(%1 : memref<?xf32>)
^bb3(%2: memref<?xf32>):
test.copy(%2, %arg2) : (memref<?xf32>, memref<?xf32>)
return
}
// CHECK-NEXT: cond_br
// CHECK-NEXT: cf.cond_br
// CHECK: %[[ALLOC0:.*]] = bufferization.clone
// CHECK-NEXT: br ^bb3(%[[ALLOC0]]
// CHECK-NEXT: cf.br ^bb3(%[[ALLOC0]]
// CHECK: ^bb2(%[[IDX:.*]]:{{.*}})
// CHECK-NEXT: %[[ALLOC1:.*]] = memref.alloc(%[[IDX]])
// CHECK-NEXT: test.buffer_based
// CHECK-NEXT: %[[ALLOC2:.*]] = bufferization.clone
// CHECK-NEXT: memref.dealloc %[[ALLOC1]]
// CHECK-NEXT: br ^bb3
// CHECK-NEXT: cf.br ^bb3
// CHECK-NEXT: ^bb3(%[[ALLOC3:.*]]:{{.*}})
// CHECK: test.copy(%[[ALLOC3]],
// CHECK-NEXT: memref.dealloc %[[ALLOC3]]
@ -98,28 +98,28 @@ func @condBranchUnrankedType(
%arg1: memref<*xf32>,
%arg2: memref<*xf32>,
%arg3: index) {
cond_br %arg0, ^bb1, ^bb2(%arg3: index)
cf.cond_br %arg0, ^bb1, ^bb2(%arg3: index)
^bb1:
br ^bb3(%arg1 : memref<*xf32>)
cf.br ^bb3(%arg1 : memref<*xf32>)
^bb2(%0: index):
%1 = memref.alloc(%0) : memref<?xf32>
%2 = memref.cast %1 : memref<?xf32> to memref<*xf32>
test.buffer_based in(%arg1: memref<*xf32>) out(%2: memref<*xf32>)
br ^bb3(%2 : memref<*xf32>)
cf.br ^bb3(%2 : memref<*xf32>)
^bb3(%3: memref<*xf32>):
test.copy(%3, %arg2) : (memref<*xf32>, memref<*xf32>)
return
}
// CHECK-NEXT: cond_br
// CHECK-NEXT: cf.cond_br
// CHECK: %[[ALLOC0:.*]] = bufferization.clone
// CHECK-NEXT: br ^bb3(%[[ALLOC0]]
// CHECK-NEXT: cf.br ^bb3(%[[ALLOC0]]
// CHECK: ^bb2(%[[IDX:.*]]:{{.*}})
// CHECK-NEXT: %[[ALLOC1:.*]] = memref.alloc(%[[IDX]])
// CHECK: test.buffer_based
// CHECK-NEXT: %[[ALLOC2:.*]] = bufferization.clone
// CHECK-NEXT: memref.dealloc %[[ALLOC1]]
// CHECK-NEXT: br ^bb3
// CHECK-NEXT: cf.br ^bb3
// CHECK-NEXT: ^bb3(%[[ALLOC3:.*]]:{{.*}})
// CHECK: test.copy(%[[ALLOC3]],
// CHECK-NEXT: memref.dealloc %[[ALLOC3]]
@ -153,44 +153,44 @@ func @condBranchDynamicTypeNested(
%arg1: memref<?xf32>,
%arg2: memref<?xf32>,
%arg3: index) {
cond_br %arg0, ^bb1, ^bb2(%arg3: index)
cf.cond_br %arg0, ^bb1, ^bb2(%arg3: index)
^bb1:
br ^bb6(%arg1 : memref<?xf32>)
cf.br ^bb6(%arg1 : memref<?xf32>)
^bb2(%0: index):
%1 = memref.alloc(%0) : memref<?xf32>
test.buffer_based in(%arg1: memref<?xf32>) out(%1: memref<?xf32>)
cond_br %arg0, ^bb3, ^bb4
cf.cond_br %arg0, ^bb3, ^bb4
^bb3:
br ^bb5(%1 : memref<?xf32>)
cf.br ^bb5(%1 : memref<?xf32>)
^bb4:
br ^bb5(%1 : memref<?xf32>)
cf.br ^bb5(%1 : memref<?xf32>)
^bb5(%2: memref<?xf32>):
br ^bb6(%2 : memref<?xf32>)
cf.br ^bb6(%2 : memref<?xf32>)
^bb6(%3: memref<?xf32>):
br ^bb7(%3 : memref<?xf32>)
cf.br ^bb7(%3 : memref<?xf32>)
^bb7(%4: memref<?xf32>):
test.copy(%4, %arg2) : (memref<?xf32>, memref<?xf32>)
return
}
// CHECK-NEXT: cond_br{{.*}}
// CHECK-NEXT: cf.cond_br{{.*}}
// CHECK-NEXT: ^bb1
// CHECK-NEXT: %[[ALLOC0:.*]] = bufferization.clone
// CHECK-NEXT: br ^bb6(%[[ALLOC0]]
// CHECK-NEXT: cf.br ^bb6(%[[ALLOC0]]
// CHECK: ^bb2(%[[IDX:.*]]:{{.*}})
// CHECK-NEXT: %[[ALLOC1:.*]] = memref.alloc(%[[IDX]])
// CHECK-NEXT: test.buffer_based
// CHECK: cond_br
// CHECK: cf.cond_br
// CHECK: ^bb3:
// CHECK-NEXT: br ^bb5(%[[ALLOC1]]{{.*}})
// CHECK-NEXT: cf.br ^bb5(%[[ALLOC1]]{{.*}})
// CHECK: ^bb4:
// CHECK-NEXT: br ^bb5(%[[ALLOC1]]{{.*}})
// CHECK-NEXT: cf.br ^bb5(%[[ALLOC1]]{{.*}})
// CHECK-NEXT: ^bb5(%[[ALLOC2:.*]]:{{.*}})
// CHECK-NEXT: %[[ALLOC3:.*]] = bufferization.clone %[[ALLOC2]]
// CHECK-NEXT: memref.dealloc %[[ALLOC1]]
// CHECK-NEXT: br ^bb6(%[[ALLOC3]]{{.*}})
// CHECK-NEXT: cf.br ^bb6(%[[ALLOC3]]{{.*}})
// CHECK-NEXT: ^bb6(%[[ALLOC4:.*]]:{{.*}})
// CHECK-NEXT: br ^bb7(%[[ALLOC4]]{{.*}})
// CHECK-NEXT: cf.br ^bb7(%[[ALLOC4]]{{.*}})
// CHECK-NEXT: ^bb7(%[[ALLOC5:.*]]:{{.*}})
// CHECK: test.copy(%[[ALLOC5]],
// CHECK-NEXT: memref.dealloc %[[ALLOC4]]
@ -225,18 +225,18 @@ func @emptyUsesValue(%arg0: memref<4xf32>) {
// CHECK-LABEL: func @criticalEdge
func @criticalEdge(%arg0: i1, %arg1: memref<2xf32>, %arg2: memref<2xf32>) {
cond_br %arg0, ^bb1, ^bb2(%arg1 : memref<2xf32>)
cf.cond_br %arg0, ^bb1, ^bb2(%arg1 : memref<2xf32>)
^bb1:
%0 = memref.alloc() : memref<2xf32>
test.buffer_based in(%arg1: memref<2xf32>) out(%0: memref<2xf32>)
br ^bb2(%0 : memref<2xf32>)
cf.br ^bb2(%0 : memref<2xf32>)
^bb2(%1: memref<2xf32>):
test.copy(%1, %arg2) : (memref<2xf32>, memref<2xf32>)
return
}
// CHECK-NEXT: %[[ALLOC0:.*]] = bufferization.clone
// CHECK-NEXT: cond_br
// CHECK-NEXT: cf.cond_br
// CHECK: %[[ALLOC1:.*]] = memref.alloc()
// CHECK-NEXT: test.buffer_based
// CHECK-NEXT: %[[ALLOC2:.*]] = bufferization.clone %[[ALLOC1]]
@ -260,9 +260,9 @@ func @criticalEdge(%arg0: i1, %arg1: memref<2xf32>, %arg2: memref<2xf32>) {
func @invCriticalEdge(%arg0: i1, %arg1: memref<2xf32>, %arg2: memref<2xf32>) {
%0 = memref.alloc() : memref<2xf32>
test.buffer_based in(%arg1: memref<2xf32>) out(%0: memref<2xf32>)
cond_br %arg0, ^bb1, ^bb2(%arg1 : memref<2xf32>)
cf.cond_br %arg0, ^bb1, ^bb2(%arg1 : memref<2xf32>)
^bb1:
br ^bb2(%0 : memref<2xf32>)
cf.br ^bb2(%0 : memref<2xf32>)
^bb2(%1: memref<2xf32>):
test.copy(%1, %arg2) : (memref<2xf32>, memref<2xf32>)
return
@ -288,13 +288,13 @@ func @invCriticalEdge(%arg0: i1, %arg1: memref<2xf32>, %arg2: memref<2xf32>) {
func @ifElse(%arg0: i1, %arg1: memref<2xf32>, %arg2: memref<2xf32>) {
%0 = memref.alloc() : memref<2xf32>
test.buffer_based in(%arg1: memref<2xf32>) out(%0: memref<2xf32>)
cond_br %arg0,
cf.cond_br %arg0,
^bb1(%arg1, %0 : memref<2xf32>, memref<2xf32>),
^bb2(%0, %arg1 : memref<2xf32>, memref<2xf32>)
^bb1(%1: memref<2xf32>, %2: memref<2xf32>):
br ^bb3(%1, %2 : memref<2xf32>, memref<2xf32>)
cf.br ^bb3(%1, %2 : memref<2xf32>, memref<2xf32>)
^bb2(%3: memref<2xf32>, %4: memref<2xf32>):
br ^bb3(%3, %4 : memref<2xf32>, memref<2xf32>)
cf.br ^bb3(%3, %4 : memref<2xf32>, memref<2xf32>)
^bb3(%5: memref<2xf32>, %6: memref<2xf32>):
%7 = memref.alloc() : memref<2xf32>
test.buffer_based in(%5: memref<2xf32>) out(%7: memref<2xf32>)
@ -326,13 +326,13 @@ func @ifElse(%arg0: i1, %arg1: memref<2xf32>, %arg2: memref<2xf32>) {
func @ifElseNoUsers(%arg0: i1, %arg1: memref<2xf32>, %arg2: memref<2xf32>) {
%0 = memref.alloc() : memref<2xf32>
test.buffer_based in(%arg1: memref<2xf32>) out(%0: memref<2xf32>)
cond_br %arg0,
cf.cond_br %arg0,
^bb1(%arg1, %0 : memref<2xf32>, memref<2xf32>),
^bb2(%0, %arg1 : memref<2xf32>, memref<2xf32>)
^bb1(%1: memref<2xf32>, %2: memref<2xf32>):
br ^bb3(%1, %2 : memref<2xf32>, memref<2xf32>)
cf.br ^bb3(%1, %2 : memref<2xf32>, memref<2xf32>)
^bb2(%3: memref<2xf32>, %4: memref<2xf32>):
br ^bb3(%3, %4 : memref<2xf32>, memref<2xf32>)
cf.br ^bb3(%3, %4 : memref<2xf32>, memref<2xf32>)
^bb3(%5: memref<2xf32>, %6: memref<2xf32>):
test.copy(%arg1, %arg2) : (memref<2xf32>, memref<2xf32>)
return
@ -361,17 +361,17 @@ func @ifElseNoUsers(%arg0: i1, %arg1: memref<2xf32>, %arg2: memref<2xf32>) {
func @ifElseNested(%arg0: i1, %arg1: memref<2xf32>, %arg2: memref<2xf32>) {
%0 = memref.alloc() : memref<2xf32>
test.buffer_based in(%arg1: memref<2xf32>) out(%0: memref<2xf32>)
cond_br %arg0,
cf.cond_br %arg0,
^bb1(%arg1, %0 : memref<2xf32>, memref<2xf32>),
^bb2(%0, %arg1 : memref<2xf32>, memref<2xf32>)
^bb1(%1: memref<2xf32>, %2: memref<2xf32>):
br ^bb5(%1, %2 : memref<2xf32>, memref<2xf32>)
cf.br ^bb5(%1, %2 : memref<2xf32>, memref<2xf32>)
^bb2(%3: memref<2xf32>, %4: memref<2xf32>):
cond_br %arg0, ^bb3(%3 : memref<2xf32>), ^bb4(%4 : memref<2xf32>)
cf.cond_br %arg0, ^bb3(%3 : memref<2xf32>), ^bb4(%4 : memref<2xf32>)
^bb3(%5: memref<2xf32>):
br ^bb5(%5, %3 : memref<2xf32>, memref<2xf32>)
cf.br ^bb5(%5, %3 : memref<2xf32>, memref<2xf32>)
^bb4(%6: memref<2xf32>):
br ^bb5(%3, %6 : memref<2xf32>, memref<2xf32>)
cf.br ^bb5(%3, %6 : memref<2xf32>, memref<2xf32>)
^bb5(%7: memref<2xf32>, %8: memref<2xf32>):
%9 = memref.alloc() : memref<2xf32>
test.buffer_based in(%7: memref<2xf32>) out(%9: memref<2xf32>)
@ -430,33 +430,33 @@ func @moving_alloc_and_inserting_missing_dealloc(
%cond: i1,
%arg0: memref<2xf32>,
%arg1: memref<2xf32>) {
cond_br %cond, ^bb1, ^bb2
cf.cond_br %cond, ^bb1, ^bb2
^bb1:
%0 = memref.alloc() : memref<2xf32>
test.buffer_based in(%arg0: memref<2xf32>) out(%0: memref<2xf32>)
br ^exit(%0 : memref<2xf32>)
cf.br ^exit(%0 : memref<2xf32>)
^bb2:
%1 = memref.alloc() : memref<2xf32>
test.buffer_based in(%arg0: memref<2xf32>) out(%1: memref<2xf32>)
br ^exit(%1 : memref<2xf32>)
cf.br ^exit(%1 : memref<2xf32>)
^exit(%arg2: memref<2xf32>):
test.copy(%arg2, %arg1) : (memref<2xf32>, memref<2xf32>)
return
}
// CHECK-NEXT: cond_br{{.*}}
// CHECK-NEXT: cf.cond_br{{.*}}
// CHECK-NEXT: ^bb1
// CHECK: %[[ALLOC0:.*]] = memref.alloc()
// CHECK-NEXT: test.buffer_based
// CHECK-NEXT: %[[ALLOC1:.*]] = bufferization.clone %[[ALLOC0]]
// CHECK-NEXT: memref.dealloc %[[ALLOC0]]
// CHECK-NEXT: br ^bb3(%[[ALLOC1]]
// CHECK-NEXT: cf.br ^bb3(%[[ALLOC1]]
// CHECK-NEXT: ^bb2
// CHECK-NEXT: %[[ALLOC2:.*]] = memref.alloc()
// CHECK-NEXT: test.buffer_based
// CHECK-NEXT: %[[ALLOC3:.*]] = bufferization.clone %[[ALLOC2]]
// CHECK-NEXT: memref.dealloc %[[ALLOC2]]
// CHECK-NEXT: br ^bb3(%[[ALLOC3]]
// CHECK-NEXT: cf.br ^bb3(%[[ALLOC3]]
// CHECK-NEXT: ^bb3(%[[ALLOC4:.*]]:{{.*}})
// CHECK: test.copy
// CHECK-NEXT: memref.dealloc %[[ALLOC4]]
@ -480,20 +480,20 @@ func @moving_invalid_dealloc_op_complex(
%arg0: memref<2xf32>,
%arg1: memref<2xf32>) {
%1 = memref.alloc() : memref<2xf32>
cond_br %cond, ^bb1, ^bb2
cf.cond_br %cond, ^bb1, ^bb2
^bb1:
br ^exit(%arg0 : memref<2xf32>)
cf.br ^exit(%arg0 : memref<2xf32>)
^bb2:
test.buffer_based in(%arg0: memref<2xf32>) out(%1: memref<2xf32>)
memref.dealloc %1 : memref<2xf32>
br ^exit(%1 : memref<2xf32>)
cf.br ^exit(%1 : memref<2xf32>)
^exit(%arg2: memref<2xf32>):
test.copy(%arg2, %arg1) : (memref<2xf32>, memref<2xf32>)
return
}
// CHECK-NEXT: %[[ALLOC0:.*]] = memref.alloc()
// CHECK-NEXT: cond_br
// CHECK-NEXT: cf.cond_br
// CHECK: test.copy
// CHECK-NEXT: memref.dealloc %[[ALLOC0]]
// CHECK-NEXT: return
@ -548,9 +548,9 @@ func @nested_regions_and_cond_branch(
%arg0: i1,
%arg1: memref<2xf32>,
%arg2: memref<2xf32>) {
cond_br %arg0, ^bb1, ^bb2
cf.cond_br %arg0, ^bb1, ^bb2
^bb1:
br ^bb3(%arg1 : memref<2xf32>)
cf.br ^bb3(%arg1 : memref<2xf32>)
^bb2:
%0 = memref.alloc() : memref<2xf32>
test.region_buffer_based in(%arg1: memref<2xf32>) out(%0: memref<2xf32>) {
@ -560,13 +560,13 @@ func @nested_regions_and_cond_branch(
%tmp1 = math.exp %gen1_arg0 : f32
test.region_yield %tmp1 : f32
}
br ^bb3(%0 : memref<2xf32>)
cf.br ^bb3(%0 : memref<2xf32>)
^bb3(%1: memref<2xf32>):
test.copy(%1, %arg2) : (memref<2xf32>, memref<2xf32>)
return
}
// CHECK: (%[[cond:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}, %{{.*}}: {{.*}})
// CHECK-NEXT: cond_br %[[cond]], ^[[BB1:.*]], ^[[BB2:.*]]
// CHECK-NEXT: cf.cond_br %[[cond]], ^[[BB1:.*]], ^[[BB2:.*]]
// CHECK: %[[ALLOC0:.*]] = bufferization.clone %[[ARG1]]
// CHECK: ^[[BB2]]:
// CHECK: %[[ALLOC1:.*]] = memref.alloc()
@ -728,21 +728,21 @@ func @subview(%arg0 : index, %arg1 : index, %arg2 : memref<?x?xf32>) {
// CHECK-LABEL: func @condBranchAlloca
func @condBranchAlloca(%arg0: i1, %arg1: memref<2xf32>, %arg2: memref<2xf32>) {
cond_br %arg0, ^bb1, ^bb2
cf.cond_br %arg0, ^bb1, ^bb2
^bb1:
br ^bb3(%arg1 : memref<2xf32>)
cf.br ^bb3(%arg1 : memref<2xf32>)
^bb2:
%0 = memref.alloca() : memref<2xf32>
test.buffer_based in(%arg1: memref<2xf32>) out(%0: memref<2xf32>)
br ^bb3(%0 : memref<2xf32>)
cf.br ^bb3(%0 : memref<2xf32>)
^bb3(%1: memref<2xf32>):
test.copy(%1, %arg2) : (memref<2xf32>, memref<2xf32>)
return
}
// CHECK-NEXT: cond_br
// CHECK-NEXT: cf.cond_br
// CHECK: %[[ALLOCA:.*]] = memref.alloca()
// CHECK: br ^bb3(%[[ALLOCA:.*]])
// CHECK: cf.br ^bb3(%[[ALLOCA:.*]])
// CHECK-NEXT: ^bb3
// CHECK-NEXT: test.copy
// CHECK-NEXT: return
@ -757,13 +757,13 @@ func @condBranchAlloca(%arg0: i1, %arg1: memref<2xf32>, %arg2: memref<2xf32>) {
func @ifElseAlloca(%arg0: i1, %arg1: memref<2xf32>, %arg2: memref<2xf32>) {
%0 = memref.alloc() : memref<2xf32>
test.buffer_based in(%arg1: memref<2xf32>) out(%0: memref<2xf32>)
cond_br %arg0,
cf.cond_br %arg0,
^bb1(%arg1, %0 : memref<2xf32>, memref<2xf32>),
^bb2(%0, %arg1 : memref<2xf32>, memref<2xf32>)
^bb1(%1: memref<2xf32>, %2: memref<2xf32>):
br ^bb3(%1, %2 : memref<2xf32>, memref<2xf32>)
cf.br ^bb3(%1, %2 : memref<2xf32>, memref<2xf32>)
^bb2(%3: memref<2xf32>, %4: memref<2xf32>):
br ^bb3(%3, %4 : memref<2xf32>, memref<2xf32>)
cf.br ^bb3(%3, %4 : memref<2xf32>, memref<2xf32>)
^bb3(%5: memref<2xf32>, %6: memref<2xf32>):
%7 = memref.alloca() : memref<2xf32>
test.buffer_based in(%5: memref<2xf32>) out(%7: memref<2xf32>)
@ -788,17 +788,17 @@ func @ifElseNestedAlloca(
%arg2: memref<2xf32>) {
%0 = memref.alloca() : memref<2xf32>
test.buffer_based in(%arg1: memref<2xf32>) out(%0: memref<2xf32>)
cond_br %arg0,
cf.cond_br %arg0,
^bb1(%arg1, %0 : memref<2xf32>, memref<2xf32>),
^bb2(%0, %arg1 : memref<2xf32>, memref<2xf32>)
^bb1(%1: memref<2xf32>, %2: memref<2xf32>):
br ^bb5(%1, %2 : memref<2xf32>, memref<2xf32>)
cf.br ^bb5(%1, %2 : memref<2xf32>, memref<2xf32>)
^bb2(%3: memref<2xf32>, %4: memref<2xf32>):
cond_br %arg0, ^bb3(%3 : memref<2xf32>), ^bb4(%4 : memref<2xf32>)
cf.cond_br %arg0, ^bb3(%3 : memref<2xf32>), ^bb4(%4 : memref<2xf32>)
^bb3(%5: memref<2xf32>):
br ^bb5(%5, %3 : memref<2xf32>, memref<2xf32>)
cf.br ^bb5(%5, %3 : memref<2xf32>, memref<2xf32>)
^bb4(%6: memref<2xf32>):
br ^bb5(%3, %6 : memref<2xf32>, memref<2xf32>)
cf.br ^bb5(%3, %6 : memref<2xf32>, memref<2xf32>)
^bb5(%7: memref<2xf32>, %8: memref<2xf32>):
%9 = memref.alloc() : memref<2xf32>
test.buffer_based in(%7: memref<2xf32>) out(%9: memref<2xf32>)
@ -821,9 +821,9 @@ func @nestedRegionsAndCondBranchAlloca(
%arg0: i1,
%arg1: memref<2xf32>,
%arg2: memref<2xf32>) {
cond_br %arg0, ^bb1, ^bb2
cf.cond_br %arg0, ^bb1, ^bb2
^bb1:
br ^bb3(%arg1 : memref<2xf32>)
cf.br ^bb3(%arg1 : memref<2xf32>)
^bb2:
%0 = memref.alloc() : memref<2xf32>
test.region_buffer_based in(%arg1: memref<2xf32>) out(%0: memref<2xf32>) {
@ -833,13 +833,13 @@ func @nestedRegionsAndCondBranchAlloca(
%tmp1 = math.exp %gen1_arg0 : f32
test.region_yield %tmp1 : f32
}
br ^bb3(%0 : memref<2xf32>)
cf.br ^bb3(%0 : memref<2xf32>)
^bb3(%1: memref<2xf32>):
test.copy(%1, %arg2) : (memref<2xf32>, memref<2xf32>)
return
}
// CHECK: (%[[cond:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}, %{{.*}}: {{.*}})
// CHECK-NEXT: cond_br %[[cond]], ^[[BB1:.*]], ^[[BB2:.*]]
// CHECK-NEXT: cf.cond_br %[[cond]], ^[[BB1:.*]], ^[[BB2:.*]]
// CHECK: ^[[BB1]]:
// CHECK: %[[ALLOC0:.*]] = bufferization.clone
// CHECK: ^[[BB2]]:
@ -1103,11 +1103,11 @@ func @loop_dynalloc(
%arg2: memref<?xf32>,
%arg3: memref<?xf32>) {
%const0 = arith.constant 0 : i32
br ^loopHeader(%const0, %arg2 : i32, memref<?xf32>)
cf.br ^loopHeader(%const0, %arg2 : i32, memref<?xf32>)
^loopHeader(%i : i32, %buff : memref<?xf32>):
%lessThan = arith.cmpi slt, %i, %arg1 : i32
cond_br %lessThan,
cf.cond_br %lessThan,
^loopBody(%i, %buff : i32, memref<?xf32>),
^exit(%buff : memref<?xf32>)
@ -1116,7 +1116,7 @@ func @loop_dynalloc(
%inc = arith.addi %val, %const1 : i32
%size = arith.index_cast %inc : i32 to index
%alloc1 = memref.alloc(%size) : memref<?xf32>
br ^loopHeader(%inc, %alloc1 : i32, memref<?xf32>)
cf.br ^loopHeader(%inc, %alloc1 : i32, memref<?xf32>)
^exit(%buff3 : memref<?xf32>):
test.copy(%buff3, %arg3) : (memref<?xf32>, memref<?xf32>)
@ -1136,17 +1136,17 @@ func @do_loop_alloc(
%arg2: memref<2xf32>,
%arg3: memref<2xf32>) {
%const0 = arith.constant 0 : i32
br ^loopBody(%const0, %arg2 : i32, memref<2xf32>)
cf.br ^loopBody(%const0, %arg2 : i32, memref<2xf32>)
^loopBody(%val : i32, %buff2: memref<2xf32>):
%const1 = arith.constant 1 : i32
%inc = arith.addi %val, %const1 : i32
%alloc1 = memref.alloc() : memref<2xf32>
br ^loopHeader(%inc, %alloc1 : i32, memref<2xf32>)
cf.br ^loopHeader(%inc, %alloc1 : i32, memref<2xf32>)
^loopHeader(%i : i32, %buff : memref<2xf32>):
%lessThan = arith.cmpi slt, %i, %arg1 : i32
cond_br %lessThan,
cf.cond_br %lessThan,
^loopBody(%i, %buff : i32, memref<2xf32>),
^exit(%buff : memref<2xf32>)

Some files were not shown because too many files have changed in this diff Show More