[mlir] Split out a new ControlFlow dialect from Standard

This dialect is intended to model lower level/branch based control-flow constructs. The initial set
of operations are: AssertOp, BranchOp, CondBranchOp, SwitchOp; all split out from the current
standard dialect.

See https://discourse.llvm.org/t/standard-dialect-the-final-chapter/6061

Differential Revision: https://reviews.llvm.org/D118966
This commit is contained in:
River Riddle 2022-02-03 20:59:43 -08:00
parent edca177cbe
commit ace01605e0
239 changed files with 3027 additions and 2585 deletions

View File

@ -27,8 +27,8 @@ namespace fir::support {
#define FLANG_NONCODEGEN_DIALECT_LIST \ #define FLANG_NONCODEGEN_DIALECT_LIST \
mlir::AffineDialect, FIROpsDialect, mlir::acc::OpenACCDialect, \ mlir::AffineDialect, FIROpsDialect, mlir::acc::OpenACCDialect, \
mlir::omp::OpenMPDialect, mlir::scf::SCFDialect, \ mlir::omp::OpenMPDialect, mlir::scf::SCFDialect, \
mlir::arith::ArithmeticDialect, mlir::StandardOpsDialect, \ mlir::arith::ArithmeticDialect, mlir::cf::ControlFlowDialect, \
mlir::vector::VectorDialect mlir::StandardOpsDialect, mlir::vector::VectorDialect
// The definitive list of dialects used by flang. // The definitive list of dialects used by flang.
#define FLANG_DIALECT_LIST \ #define FLANG_DIALECT_LIST \

View File

@ -9,7 +9,7 @@
/// This file defines some shared command-line options that can be used when /// This file defines some shared command-line options that can be used when
/// debugging the test tools. This file must be included into the tool. /// debugging the test tools. This file must be included into the tool.
#include "mlir/Conversion/SCFToStandard/SCFToStandard.h" #include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h"
#include "mlir/Pass/PassManager.h" #include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/Passes.h" #include "mlir/Transforms/Passes.h"
@ -139,7 +139,7 @@ inline void createDefaultFIROptimizerPassPipeline(mlir::PassManager &pm) {
// convert control flow to CFG form // convert control flow to CFG form
fir::addCfgConversionPass(pm); fir::addCfgConversionPass(pm);
pm.addPass(mlir::createLowerToCFGPass()); pm.addPass(mlir::createConvertSCFToCFPass());
pm.addPass(mlir::createCanonicalizerPass(config)); pm.addPass(mlir::createCanonicalizerPass(config));
} }

View File

@ -32,7 +32,7 @@ add_flang_library(FortranLower
FortranSemantics FortranSemantics
MLIRAffineToStandard MLIRAffineToStandard
MLIRLLVMIR MLIRLLVMIR
MLIRSCFToStandard MLIRSCFToControlFlow
MLIRStandard MLIRStandard
LINK_COMPONENTS LINK_COMPONENTS

View File

@ -18,6 +18,7 @@
#include "flang/Optimizer/Dialect/FIROps.h" #include "flang/Optimizer/Dialect/FIROps.h"
#include "flang/Optimizer/Support/TypeCode.h" #include "flang/Optimizer/Support/TypeCode.h"
#include "mlir/Conversion/ArithmeticToLLVM/ArithmeticToLLVM.h" #include "mlir/Conversion/ArithmeticToLLVM/ArithmeticToLLVM.h"
#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h" #include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
#include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/BuiltinTypes.h"
@ -3293,6 +3294,8 @@ public:
mlir::populateStdToLLVMConversionPatterns(typeConverter, pattern); mlir::populateStdToLLVMConversionPatterns(typeConverter, pattern);
mlir::arith::populateArithmeticToLLVMConversionPatterns(typeConverter, mlir::arith::populateArithmeticToLLVMConversionPatterns(typeConverter,
pattern); pattern);
mlir::cf::populateControlFlowToLLVMConversionPatterns(typeConverter,
pattern);
mlir::ConversionTarget target{*context}; mlir::ConversionTarget target{*context};
target.addLegalDialect<mlir::LLVM::LLVMDialect>(); target.addLegalDialect<mlir::LLVM::LLVMDialect>();

View File

@ -13,6 +13,7 @@
#include "flang/Optimizer/Dialect/FIRDialect.h" #include "flang/Optimizer/Dialect/FIRDialect.h"
#include "flang/Optimizer/Support/FIRContext.h" #include "flang/Optimizer/Support/FIRContext.h"
#include "flang/Optimizer/Transforms/Passes.h" #include "flang/Optimizer/Transforms/Passes.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/SCF/SCF.h" #include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/DialectConversion.h"
#include "llvm/Support/Debug.h" #include "llvm/Support/Debug.h"
@ -332,9 +333,9 @@ ArrayCopyAnalysis::arrayAccesses(ArrayLoadOp load) {
<< "add modify {" << *owner << "} to array value set\n"); << "add modify {" << *owner << "} to array value set\n");
accesses.push_back(owner); accesses.push_back(owner);
appendToQueue(update.getResult(1)); appendToQueue(update.getResult(1));
} else if (auto br = mlir::dyn_cast<mlir::BranchOp>(owner)) { } else if (auto br = mlir::dyn_cast<mlir::cf::BranchOp>(owner)) {
branchOp(br.getDest(), br.getDestOperands()); branchOp(br.getDest(), br.getDestOperands());
} else if (auto br = mlir::dyn_cast<mlir::CondBranchOp>(owner)) { } else if (auto br = mlir::dyn_cast<mlir::cf::CondBranchOp>(owner)) {
branchOp(br.getTrueDest(), br.getTrueOperands()); branchOp(br.getTrueDest(), br.getTrueOperands());
branchOp(br.getFalseDest(), br.getFalseOperands()); branchOp(br.getFalseDest(), br.getFalseOperands());
} else if (mlir::isa<ArrayMergeStoreOp>(owner)) { } else if (mlir::isa<ArrayMergeStoreOp>(owner)) {
@ -789,9 +790,9 @@ public:
patterns1.insert<ArrayUpdateConversion>(context, analysis, useMap); patterns1.insert<ArrayUpdateConversion>(context, analysis, useMap);
patterns1.insert<ArrayModifyConversion>(context, analysis, useMap); patterns1.insert<ArrayModifyConversion>(context, analysis, useMap);
mlir::ConversionTarget target(*context); mlir::ConversionTarget target(*context);
target.addLegalDialect<FIROpsDialect, mlir::scf::SCFDialect, target.addLegalDialect<
mlir::arith::ArithmeticDialect, FIROpsDialect, mlir::scf::SCFDialect, mlir::arith::ArithmeticDialect,
mlir::StandardOpsDialect>(); mlir::cf::ControlFlowDialect, mlir::StandardOpsDialect>();
target.addIllegalOp<ArrayFetchOp, ArrayUpdateOp, ArrayModifyOp>(); target.addIllegalOp<ArrayFetchOp, ArrayUpdateOp, ArrayModifyOp>();
// Rewrite the array fetch and array update ops. // Rewrite the array fetch and array update ops.
if (mlir::failed( if (mlir::failed(

View File

@ -11,6 +11,7 @@
#include "flang/Optimizer/Dialect/FIROps.h" #include "flang/Optimizer/Dialect/FIROps.h"
#include "flang/Optimizer/Transforms/Passes.h" #include "flang/Optimizer/Transforms/Passes.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Pass/Pass.h" #include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/DialectConversion.h"
@ -84,7 +85,7 @@ public:
loopOperands.append(operands.begin(), operands.end()); loopOperands.append(operands.begin(), operands.end());
loopOperands.push_back(iters); loopOperands.push_back(iters);
rewriter.create<mlir::BranchOp>(loc, conditionalBlock, loopOperands); rewriter.create<mlir::cf::BranchOp>(loc, conditionalBlock, loopOperands);
// Last loop block // Last loop block
auto *terminator = lastBlock->getTerminator(); auto *terminator = lastBlock->getTerminator();
@ -105,7 +106,7 @@ public:
: terminator->operand_begin(); : terminator->operand_begin();
loopCarried.append(begin, terminator->operand_end()); loopCarried.append(begin, terminator->operand_end());
loopCarried.push_back(itersMinusOne); loopCarried.push_back(itersMinusOne);
rewriter.create<mlir::BranchOp>(loc, conditionalBlock, loopCarried); rewriter.create<mlir::cf::BranchOp>(loc, conditionalBlock, loopCarried);
rewriter.eraseOp(terminator); rewriter.eraseOp(terminator);
// Conditional block // Conditional block
@ -114,9 +115,9 @@ public:
auto comparison = rewriter.create<mlir::arith::CmpIOp>( auto comparison = rewriter.create<mlir::arith::CmpIOp>(
loc, arith::CmpIPredicate::sgt, itersLeft, zero); loc, arith::CmpIPredicate::sgt, itersLeft, zero);
rewriter.create<mlir::CondBranchOp>(loc, comparison, firstBlock, rewriter.create<mlir::cf::CondBranchOp>(
llvm::ArrayRef<mlir::Value>(), endBlock, loc, comparison, firstBlock, llvm::ArrayRef<mlir::Value>(), endBlock,
llvm::ArrayRef<mlir::Value>()); llvm::ArrayRef<mlir::Value>());
// The result of the loop operation is the values of the condition block // The result of the loop operation is the values of the condition block
// arguments except the induction variable on the last iteration. // arguments except the induction variable on the last iteration.
@ -155,7 +156,7 @@ public:
} else { } else {
continueBlock = continueBlock =
rewriter.createBlock(remainingOpsBlock, ifOp.getResultTypes()); rewriter.createBlock(remainingOpsBlock, ifOp.getResultTypes());
rewriter.create<mlir::BranchOp>(loc, remainingOpsBlock); rewriter.create<mlir::cf::BranchOp>(loc, remainingOpsBlock);
} }
// Move blocks from the "then" region to the region containing 'fir.if', // Move blocks from the "then" region to the region containing 'fir.if',
@ -165,7 +166,8 @@ public:
auto *ifOpTerminator = ifOpRegion.back().getTerminator(); auto *ifOpTerminator = ifOpRegion.back().getTerminator();
auto ifOpTerminatorOperands = ifOpTerminator->getOperands(); auto ifOpTerminatorOperands = ifOpTerminator->getOperands();
rewriter.setInsertionPointToEnd(&ifOpRegion.back()); rewriter.setInsertionPointToEnd(&ifOpRegion.back());
rewriter.create<mlir::BranchOp>(loc, continueBlock, ifOpTerminatorOperands); rewriter.create<mlir::cf::BranchOp>(loc, continueBlock,
ifOpTerminatorOperands);
rewriter.eraseOp(ifOpTerminator); rewriter.eraseOp(ifOpTerminator);
rewriter.inlineRegionBefore(ifOpRegion, continueBlock); rewriter.inlineRegionBefore(ifOpRegion, continueBlock);
@ -179,14 +181,14 @@ public:
auto *otherwiseTerm = otherwiseRegion.back().getTerminator(); auto *otherwiseTerm = otherwiseRegion.back().getTerminator();
auto otherwiseTermOperands = otherwiseTerm->getOperands(); auto otherwiseTermOperands = otherwiseTerm->getOperands();
rewriter.setInsertionPointToEnd(&otherwiseRegion.back()); rewriter.setInsertionPointToEnd(&otherwiseRegion.back());
rewriter.create<mlir::BranchOp>(loc, continueBlock, rewriter.create<mlir::cf::BranchOp>(loc, continueBlock,
otherwiseTermOperands); otherwiseTermOperands);
rewriter.eraseOp(otherwiseTerm); rewriter.eraseOp(otherwiseTerm);
rewriter.inlineRegionBefore(otherwiseRegion, continueBlock); rewriter.inlineRegionBefore(otherwiseRegion, continueBlock);
} }
rewriter.setInsertionPointToEnd(condBlock); rewriter.setInsertionPointToEnd(condBlock);
rewriter.create<mlir::CondBranchOp>( rewriter.create<mlir::cf::CondBranchOp>(
loc, ifOp.condition(), ifOpBlock, llvm::ArrayRef<mlir::Value>(), loc, ifOp.condition(), ifOpBlock, llvm::ArrayRef<mlir::Value>(),
otherwiseBlock, llvm::ArrayRef<mlir::Value>()); otherwiseBlock, llvm::ArrayRef<mlir::Value>());
rewriter.replaceOp(ifOp, continueBlock->getArguments()); rewriter.replaceOp(ifOp, continueBlock->getArguments());
@ -241,7 +243,7 @@ public:
auto begin = whileOp.finalValue() ? std::next(terminator->operand_begin()) auto begin = whileOp.finalValue() ? std::next(terminator->operand_begin())
: terminator->operand_begin(); : terminator->operand_begin();
loopCarried.append(begin, terminator->operand_end()); loopCarried.append(begin, terminator->operand_end());
rewriter.create<mlir::BranchOp>(loc, conditionBlock, loopCarried); rewriter.create<mlir::cf::BranchOp>(loc, conditionBlock, loopCarried);
rewriter.eraseOp(terminator); rewriter.eraseOp(terminator);
// Compute loop bounds before branching to the condition. // Compute loop bounds before branching to the condition.
@ -256,7 +258,7 @@ public:
destOperands.push_back(lowerBound); destOperands.push_back(lowerBound);
auto iterOperands = whileOp.getIterOperands(); auto iterOperands = whileOp.getIterOperands();
destOperands.append(iterOperands.begin(), iterOperands.end()); destOperands.append(iterOperands.begin(), iterOperands.end());
rewriter.create<mlir::BranchOp>(loc, conditionBlock, destOperands); rewriter.create<mlir::cf::BranchOp>(loc, conditionBlock, destOperands);
// With the body block done, we can fill in the condition block. // With the body block done, we can fill in the condition block.
rewriter.setInsertionPointToEnd(conditionBlock); rewriter.setInsertionPointToEnd(conditionBlock);
@ -278,9 +280,9 @@ public:
// Remember to AND in the early-exit bool. // Remember to AND in the early-exit bool.
auto comparison = auto comparison =
rewriter.create<mlir::arith::AndIOp>(loc, iterateVar, cmp2); rewriter.create<mlir::arith::AndIOp>(loc, iterateVar, cmp2);
rewriter.create<mlir::CondBranchOp>(loc, comparison, firstBodyBlock, rewriter.create<mlir::cf::CondBranchOp>(
llvm::ArrayRef<mlir::Value>(), endBlock, loc, comparison, firstBodyBlock, llvm::ArrayRef<mlir::Value>(),
llvm::ArrayRef<mlir::Value>()); endBlock, llvm::ArrayRef<mlir::Value>());
// The result of the loop operation is the values of the condition block // The result of the loop operation is the values of the condition block
// arguments except the induction variable on the last iteration. // arguments except the induction variable on the last iteration.
auto args = whileOp.finalValue() auto args = whileOp.finalValue()
@ -300,8 +302,8 @@ public:
patterns.insert<CfgLoopConv, CfgIfConv, CfgIterWhileConv>( patterns.insert<CfgLoopConv, CfgIfConv, CfgIterWhileConv>(
context, forceLoopToExecuteOnce); context, forceLoopToExecuteOnce);
mlir::ConversionTarget target(*context); mlir::ConversionTarget target(*context);
target.addLegalDialect<mlir::AffineDialect, FIROpsDialect, target.addLegalDialect<mlir::AffineDialect, mlir::cf::ControlFlowDialect,
mlir::StandardOpsDialect>(); FIROpsDialect, mlir::StandardOpsDialect>();
// apply the patterns // apply the patterns
target.addIllegalOp<ResultOp, DoLoopOp, IfOp, IterWhileOp>(); target.addIllegalOp<ResultOp, DoLoopOp, IfOp, IterWhileOp>();

View File

@ -10,10 +10,10 @@ func @select_case_charachter(%arg0: !fir.char<2, 10>, %arg1: !fir.char<2, 10>, %
unit, ^bb3] unit, ^bb3]
^bb1: ^bb1:
%c1_i32 = arith.constant 1 : i32 %c1_i32 = arith.constant 1 : i32
br ^bb3 cf.br ^bb3
^bb2: ^bb2:
%c2_i32 = arith.constant 2 : i32 %c2_i32 = arith.constant 2 : i32
br ^bb3 cf.br ^bb3
^bb3: ^bb3:
return return
} }

View File

@ -1175,23 +1175,23 @@ func @select_case_integer(%arg0: !fir.ref<i32>) -> i32 {
^bb1: // pred: ^bb0 ^bb1: // pred: ^bb0
%c1_i32_0 = arith.constant 1 : i32 %c1_i32_0 = arith.constant 1 : i32
fir.store %c1_i32_0 to %arg0 : !fir.ref<i32> fir.store %c1_i32_0 to %arg0 : !fir.ref<i32>
br ^bb6 cf.br ^bb6
^bb2: // pred: ^bb0 ^bb2: // pred: ^bb0
%c2_i32_1 = arith.constant 2 : i32 %c2_i32_1 = arith.constant 2 : i32
fir.store %c2_i32_1 to %arg0 : !fir.ref<i32> fir.store %c2_i32_1 to %arg0 : !fir.ref<i32>
br ^bb6 cf.br ^bb6
^bb3: // pred: ^bb0 ^bb3: // pred: ^bb0
%c0_i32 = arith.constant 0 : i32 %c0_i32 = arith.constant 0 : i32
fir.store %c0_i32 to %arg0 : !fir.ref<i32> fir.store %c0_i32 to %arg0 : !fir.ref<i32>
br ^bb6 cf.br ^bb6
^bb4: // pred: ^bb0 ^bb4: // pred: ^bb0
%c4_i32_2 = arith.constant 4 : i32 %c4_i32_2 = arith.constant 4 : i32
fir.store %c4_i32_2 to %arg0 : !fir.ref<i32> fir.store %c4_i32_2 to %arg0 : !fir.ref<i32>
br ^bb6 cf.br ^bb6
^bb5: // 3 preds: ^bb0, ^bb0, ^bb0 ^bb5: // 3 preds: ^bb0, ^bb0, ^bb0
%c7_i32_3 = arith.constant 7 : i32 %c7_i32_3 = arith.constant 7 : i32
fir.store %c7_i32_3 to %arg0 : !fir.ref<i32> fir.store %c7_i32_3 to %arg0 : !fir.ref<i32>
br ^bb6 cf.br ^bb6
^bb6: // 5 preds: ^bb1, ^bb2, ^bb3, ^bb4, ^bb5 ^bb6: // 5 preds: ^bb1, ^bb2, ^bb3, ^bb4, ^bb5
%3 = fir.load %arg0 : !fir.ref<i32> %3 = fir.load %arg0 : !fir.ref<i32>
return %3 : i32 return %3 : i32
@ -1275,10 +1275,10 @@ func @select_case_logical(%arg0: !fir.ref<!fir.logical<4>>) {
unit, ^bb3] unit, ^bb3]
^bb1: ^bb1:
%c1_i32 = arith.constant 1 : i32 %c1_i32 = arith.constant 1 : i32
br ^bb3 cf.br ^bb3
^bb2: ^bb2:
%c2_i32 = arith.constant 2 : i32 %c2_i32 = arith.constant 2 : i32
br ^bb3 cf.br ^bb3
^bb3: ^bb3:
return return
} }

View File

@ -9,10 +9,10 @@ func @load_store_chain_removal(%arg0: !fir.ref<!fir.array<60xi32>>, %arg1: !fir.
%c1 = arith.constant 1 : index %c1 = arith.constant 1 : index
%0 = fir.alloca i32 {bindc_name = "i", uniq_name = "_QFf1dcEi"} %0 = fir.alloca i32 {bindc_name = "i", uniq_name = "_QFf1dcEi"}
%1 = fir.alloca !fir.array<60xi32> {bindc_name = "t1", uniq_name = "_QFf1dcEt1"} %1 = fir.alloca !fir.array<60xi32> {bindc_name = "t1", uniq_name = "_QFf1dcEt1"}
br ^bb1(%c1, %c60 : index, index) cf.br ^bb1(%c1, %c60 : index, index)
^bb1(%2: index, %3: index): // 2 preds: ^bb0, ^bb2 ^bb1(%2: index, %3: index): // 2 preds: ^bb0, ^bb2
%4 = arith.cmpi sgt, %3, %c0 : index %4 = arith.cmpi sgt, %3, %c0 : index
cond_br %4, ^bb2, ^bb3 cf.cond_br %4, ^bb2, ^bb3
^bb2: // pred: ^bb1 ^bb2: // pred: ^bb1
%5 = fir.convert %2 : (index) -> i32 %5 = fir.convert %2 : (index) -> i32
fir.store %5 to %0 : !fir.ref<i32> fir.store %5 to %0 : !fir.ref<i32>
@ -26,14 +26,14 @@ func @load_store_chain_removal(%arg0: !fir.ref<!fir.array<60xi32>>, %arg1: !fir.
fir.store %11 to %12 : !fir.ref<i32> fir.store %11 to %12 : !fir.ref<i32>
%13 = arith.addi %2, %c1 : index %13 = arith.addi %2, %c1 : index
%14 = arith.subi %3, %c1 : index %14 = arith.subi %3, %c1 : index
br ^bb1(%13, %14 : index, index) cf.br ^bb1(%13, %14 : index, index)
^bb3: // pred: ^bb1 ^bb3: // pred: ^bb1
%15 = fir.convert %2 : (index) -> i32 %15 = fir.convert %2 : (index) -> i32
fir.store %15 to %0 : !fir.ref<i32> fir.store %15 to %0 : !fir.ref<i32>
br ^bb4(%c1, %c60 : index, index) cf.br ^bb4(%c1, %c60 : index, index)
^bb4(%16: index, %17: index): // 2 preds: ^bb3, ^bb5 ^bb4(%16: index, %17: index): // 2 preds: ^bb3, ^bb5
%18 = arith.cmpi sgt, %17, %c0 : index %18 = arith.cmpi sgt, %17, %c0 : index
cond_br %18, ^bb5, ^bb6 cf.cond_br %18, ^bb5, ^bb6
^bb5: // pred: ^bb4 ^bb5: // pred: ^bb4
%19 = fir.convert %16 : (index) -> i32 %19 = fir.convert %16 : (index) -> i32
fir.store %19 to %0 : !fir.ref<i32> fir.store %19 to %0 : !fir.ref<i32>
@ -49,7 +49,7 @@ func @load_store_chain_removal(%arg0: !fir.ref<!fir.array<60xi32>>, %arg1: !fir.
fir.store %27 to %28 : !fir.ref<i32> fir.store %27 to %28 : !fir.ref<i32>
%29 = arith.addi %16, %c1 : index %29 = arith.addi %16, %c1 : index
%30 = arith.subi %17, %c1 : index %30 = arith.subi %17, %c1 : index
br ^bb4(%29, %30 : index, index) cf.br ^bb4(%29, %30 : index, index)
^bb6: // pred: ^bb4 ^bb6: // pred: ^bb4
%31 = fir.convert %16 : (index) -> i32 %31 = fir.convert %16 : (index) -> i32
fir.store %31 to %0 : !fir.ref<i32> fir.store %31 to %0 : !fir.ref<i32>

View File

@ -13,7 +13,7 @@ FIRTransforms
FIRBuilder FIRBuilder
${dialect_libs} ${dialect_libs}
MLIRAffineToStandard MLIRAffineToStandard
MLIRSCFToStandard MLIRSCFToControlFlow
FortranCommon FortranCommon
FortranParser FortranParser
FortranEvaluate FortranEvaluate

View File

@ -38,7 +38,6 @@
#include "flang/Semantics/semantics.h" #include "flang/Semantics/semantics.h"
#include "flang/Semantics/unparse-with-symbols.h" #include "flang/Semantics/unparse-with-symbols.h"
#include "flang/Version.inc" #include "flang/Version.inc"
#include "mlir/Conversion/SCFToStandard/SCFToStandard.h"
#include "mlir/IR/AsmState.h" #include "mlir/IR/AsmState.h"
#include "mlir/IR/BuiltinOps.h" #include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/MLIRContext.h" #include "mlir/IR/MLIRContext.h"

View File

@ -18,7 +18,7 @@ target_link_libraries(fir-opt PRIVATE
MLIRTransforms MLIRTransforms
MLIRAffineToStandard MLIRAffineToStandard
MLIRAnalysis MLIRAnalysis
MLIRSCFToStandard MLIRSCFToControlFlow
MLIRParser MLIRParser
MLIRStandardToLLVM MLIRStandardToLLVM
MLIRSupport MLIRSupport

View File

@ -17,7 +17,7 @@ target_link_libraries(tco PRIVATE
MLIRTransforms MLIRTransforms
MLIRAffineToStandard MLIRAffineToStandard
MLIRAnalysis MLIRAnalysis
MLIRSCFToStandard MLIRSCFToControlFlow
MLIRParser MLIRParser
MLIRStandardToLLVM MLIRStandardToLLVM
MLIRSupport MLIRSupport

View File

@ -17,7 +17,6 @@
#include "flang/Optimizer/Support/InternalNames.h" #include "flang/Optimizer/Support/InternalNames.h"
#include "flang/Optimizer/Support/KindMapping.h" #include "flang/Optimizer/Support/KindMapping.h"
#include "flang/Optimizer/Transforms/Passes.h" #include "flang/Optimizer/Transforms/Passes.h"
#include "mlir/Conversion/SCFToStandard/SCFToStandard.h"
#include "mlir/IR/AsmState.h" #include "mlir/IR/AsmState.h"
#include "mlir/IR/BuiltinOps.h" #include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/MLIRContext.h" #include "mlir/IR/MLIRContext.h"

View File

@ -26,7 +26,7 @@ def setup_passes(mlir_module):
f"sparse-tensor-conversion," f"sparse-tensor-conversion,"
f"builtin.func" f"builtin.func"
f"(linalg-bufferize,convert-linalg-to-loops,convert-vector-to-scf)," f"(linalg-bufferize,convert-linalg-to-loops,convert-vector-to-scf),"
f"convert-scf-to-std," f"convert-scf-to-cf,"
f"func-bufferize," f"func-bufferize,"
f"arith-bufferize," f"arith-bufferize,"
f"builtin.func(tensor-bufferize,finalizing-bufferize)," f"builtin.func(tensor-bufferize,finalizing-bufferize),"

View File

@ -41,12 +41,12 @@ Example for breaking the invariant:
```mlir ```mlir
func @condBranch(%arg0: i1, %arg1: memref<2xf32>) { func @condBranch(%arg0: i1, %arg1: memref<2xf32>) {
%0 = memref.alloc() : memref<2xf32> %0 = memref.alloc() : memref<2xf32>
cond_br %arg0, ^bb1, ^bb2 cf.cond_br %arg0, ^bb1, ^bb2
^bb1: ^bb1:
br ^bb3() cf.br ^bb3()
^bb2: ^bb2:
partial_write(%0, %0) partial_write(%0, %0)
br ^bb3() cf.br ^bb3()
^bb3(): ^bb3():
test.copy(%0, %arg1) : (memref<2xf32>, memref<2xf32>) -> () test.copy(%0, %arg1) : (memref<2xf32>, memref<2xf32>) -> ()
return return
@ -74,13 +74,13 @@ untracked allocations are mixed:
func @mixedAllocation(%arg0: i1) { func @mixedAllocation(%arg0: i1) {
%0 = memref.alloca() : memref<2xf32> // aliases: %2 %0 = memref.alloca() : memref<2xf32> // aliases: %2
%1 = memref.alloc() : memref<2xf32> // aliases: %2 %1 = memref.alloc() : memref<2xf32> // aliases: %2
cond_br %arg0, ^bb1, ^bb2 cf.cond_br %arg0, ^bb1, ^bb2
^bb1: ^bb1:
use(%0) use(%0)
br ^bb3(%0 : memref<2xf32>) cf.br ^bb3(%0 : memref<2xf32>)
^bb2: ^bb2:
use(%1) use(%1)
br ^bb3(%1 : memref<2xf32>) cf.br ^bb3(%1 : memref<2xf32>)
^bb3(%2: memref<2xf32>): ^bb3(%2: memref<2xf32>):
... ...
} }
@ -129,13 +129,13 @@ BufferHoisting pass:
```mlir ```mlir
func @condBranch(%arg0: i1, %arg1: memref<2xf32>, %arg2: memref<2xf32>) { func @condBranch(%arg0: i1, %arg1: memref<2xf32>, %arg2: memref<2xf32>) {
cond_br %arg0, ^bb1, ^bb2 cf.cond_br %arg0, ^bb1, ^bb2
^bb1: ^bb1:
br ^bb3(%arg1 : memref<2xf32>) cf.br ^bb3(%arg1 : memref<2xf32>)
^bb2: ^bb2:
%0 = memref.alloc() : memref<2xf32> // aliases: %1 %0 = memref.alloc() : memref<2xf32> // aliases: %1
use(%0) use(%0)
br ^bb3(%0 : memref<2xf32>) cf.br ^bb3(%0 : memref<2xf32>)
^bb3(%1: memref<2xf32>): // %1 could be %0 or %arg1 ^bb3(%1: memref<2xf32>): // %1 could be %0 or %arg1
test.copy(%1, %arg2) : (memref<2xf32>, memref<2xf32>) -> () test.copy(%1, %arg2) : (memref<2xf32>, memref<2xf32>) -> ()
return return
@ -150,12 +150,12 @@ of code:
```mlir ```mlir
func @condBranch(%arg0: i1, %arg1: memref<2xf32>, %arg2: memref<2xf32>) { func @condBranch(%arg0: i1, %arg1: memref<2xf32>, %arg2: memref<2xf32>) {
%0 = memref.alloc() : memref<2xf32> // moved to bb0 %0 = memref.alloc() : memref<2xf32> // moved to bb0
cond_br %arg0, ^bb1, ^bb2 cf.cond_br %arg0, ^bb1, ^bb2
^bb1: ^bb1:
br ^bb3(%arg1 : memref<2xf32>) cf.br ^bb3(%arg1 : memref<2xf32>)
^bb2: ^bb2:
use(%0) use(%0)
br ^bb3(%0 : memref<2xf32>) cf.br ^bb3(%0 : memref<2xf32>)
^bb3(%1: memref<2xf32>): ^bb3(%1: memref<2xf32>):
test.copy(%1, %arg2) : (memref<2xf32>, memref<2xf32>) -> () test.copy(%1, %arg2) : (memref<2xf32>, memref<2xf32>) -> ()
return return
@ -175,14 +175,14 @@ func @condBranchDynamicType(
%arg1: memref<?xf32>, %arg1: memref<?xf32>,
%arg2: memref<?xf32>, %arg2: memref<?xf32>,
%arg3: index) { %arg3: index) {
cond_br %arg0, ^bb1, ^bb2(%arg3: index) cf.cond_br %arg0, ^bb1, ^bb2(%arg3: index)
^bb1: ^bb1:
br ^bb3(%arg1 : memref<?xf32>) cf.br ^bb3(%arg1 : memref<?xf32>)
^bb2(%0: index): ^bb2(%0: index):
%1 = memref.alloc(%0) : memref<?xf32> // cannot be moved upwards to the data %1 = memref.alloc(%0) : memref<?xf32> // cannot be moved upwards to the data
// dependency to %0 // dependency to %0
use(%1) use(%1)
br ^bb3(%1 : memref<?xf32>) cf.br ^bb3(%1 : memref<?xf32>)
^bb3(%2: memref<?xf32>): ^bb3(%2: memref<?xf32>):
test.copy(%2, %arg2) : (memref<?xf32>, memref<?xf32>) -> () test.copy(%2, %arg2) : (memref<?xf32>, memref<?xf32>) -> ()
return return
@ -201,14 +201,14 @@ allocations have already been placed:
```mlir ```mlir
func @branch(%arg0: i1) { func @branch(%arg0: i1) {
%0 = memref.alloc() : memref<2xf32> // aliases: %2 %0 = memref.alloc() : memref<2xf32> // aliases: %2
cond_br %arg0, ^bb1, ^bb2 cf.cond_br %arg0, ^bb1, ^bb2
^bb1: ^bb1:
%1 = memref.alloc() : memref<2xf32> // resides here for demonstration purposes %1 = memref.alloc() : memref<2xf32> // resides here for demonstration purposes
// aliases: %2 // aliases: %2
br ^bb3(%1 : memref<2xf32>) cf.br ^bb3(%1 : memref<2xf32>)
^bb2: ^bb2:
use(%0) use(%0)
br ^bb3(%0 : memref<2xf32>) cf.br ^bb3(%0 : memref<2xf32>)
^bb3(%2: memref<2xf32>): ^bb3(%2: memref<2xf32>):
return return
@ -233,16 +233,16 @@ result:
```mlir ```mlir
func @branch(%arg0: i1) { func @branch(%arg0: i1) {
%0 = memref.alloc() : memref<2xf32> %0 = memref.alloc() : memref<2xf32>
cond_br %arg0, ^bb1, ^bb2 cf.cond_br %arg0, ^bb1, ^bb2
^bb1: ^bb1:
%1 = memref.alloc() : memref<2xf32> %1 = memref.alloc() : memref<2xf32>
%3 = bufferization.clone %1 : (memref<2xf32>) -> (memref<2xf32>) %3 = bufferization.clone %1 : (memref<2xf32>) -> (memref<2xf32>)
memref.dealloc %1 : memref<2xf32> // %1 can be safely freed here memref.dealloc %1 : memref<2xf32> // %1 can be safely freed here
br ^bb3(%3 : memref<2xf32>) cf.br ^bb3(%3 : memref<2xf32>)
^bb2: ^bb2:
use(%0) use(%0)
%4 = bufferization.clone %0 : (memref<2xf32>) -> (memref<2xf32>) %4 = bufferization.clone %0 : (memref<2xf32>) -> (memref<2xf32>)
br ^bb3(%4 : memref<2xf32>) cf.br ^bb3(%4 : memref<2xf32>)
^bb3(%2: memref<2xf32>): ^bb3(%2: memref<2xf32>):
memref.dealloc %2 : memref<2xf32> // free temp buffer %2 memref.dealloc %2 : memref<2xf32> // free temp buffer %2
@ -273,23 +273,23 @@ func @condBranchDynamicTypeNested(
%arg1: memref<?xf32>, // aliases: %3, %4 %arg1: memref<?xf32>, // aliases: %3, %4
%arg2: memref<?xf32>, %arg2: memref<?xf32>,
%arg3: index) { %arg3: index) {
cond_br %arg0, ^bb1, ^bb2(%arg3: index) cf.cond_br %arg0, ^bb1, ^bb2(%arg3: index)
^bb1: ^bb1:
br ^bb6(%arg1 : memref<?xf32>) cf.br ^bb6(%arg1 : memref<?xf32>)
^bb2(%0: index): ^bb2(%0: index):
%1 = memref.alloc(%0) : memref<?xf32> // cannot be moved upwards due to the data %1 = memref.alloc(%0) : memref<?xf32> // cannot be moved upwards due to the data
// dependency to %0 // dependency to %0
// aliases: %2, %3, %4 // aliases: %2, %3, %4
use(%1) use(%1)
cond_br %arg0, ^bb3, ^bb4 cf.cond_br %arg0, ^bb3, ^bb4
^bb3: ^bb3:
br ^bb5(%1 : memref<?xf32>) cf.br ^bb5(%1 : memref<?xf32>)
^bb4: ^bb4:
br ^bb5(%1 : memref<?xf32>) cf.br ^bb5(%1 : memref<?xf32>)
^bb5(%2: memref<?xf32>): // non-crit. alias of %1, since %1 dominates %2 ^bb5(%2: memref<?xf32>): // non-crit. alias of %1, since %1 dominates %2
br ^bb6(%2 : memref<?xf32>) cf.br ^bb6(%2 : memref<?xf32>)
^bb6(%3: memref<?xf32>): // crit. alias of %arg1 and %2 (in other words %1) ^bb6(%3: memref<?xf32>): // crit. alias of %arg1 and %2 (in other words %1)
br ^bb7(%3 : memref<?xf32>) cf.br ^bb7(%3 : memref<?xf32>)
^bb7(%4: memref<?xf32>): // non-crit. alias of %3, since %3 dominates %4 ^bb7(%4: memref<?xf32>): // non-crit. alias of %3, since %3 dominates %4
test.copy(%4, %arg2) : (memref<?xf32>, memref<?xf32>) -> () test.copy(%4, %arg2) : (memref<?xf32>, memref<?xf32>) -> ()
return return
@ -306,25 +306,25 @@ func @condBranchDynamicTypeNested(
%arg1: memref<?xf32>, %arg1: memref<?xf32>,
%arg2: memref<?xf32>, %arg2: memref<?xf32>,
%arg3: index) { %arg3: index) {
cond_br %arg0, ^bb1, ^bb2(%arg3 : index) cf.cond_br %arg0, ^bb1, ^bb2(%arg3 : index)
^bb1: ^bb1:
// temp buffer required due to alias %3 // temp buffer required due to alias %3
%5 = bufferization.clone %arg1 : (memref<?xf32>) -> (memref<?xf32>) %5 = bufferization.clone %arg1 : (memref<?xf32>) -> (memref<?xf32>)
br ^bb6(%5 : memref<?xf32>) cf.br ^bb6(%5 : memref<?xf32>)
^bb2(%0: index): ^bb2(%0: index):
%1 = memref.alloc(%0) : memref<?xf32> %1 = memref.alloc(%0) : memref<?xf32>
use(%1) use(%1)
cond_br %arg0, ^bb3, ^bb4 cf.cond_br %arg0, ^bb3, ^bb4
^bb3: ^bb3:
br ^bb5(%1 : memref<?xf32>) cf.br ^bb5(%1 : memref<?xf32>)
^bb4: ^bb4:
br ^bb5(%1 : memref<?xf32>) cf.br ^bb5(%1 : memref<?xf32>)
^bb5(%2: memref<?xf32>): ^bb5(%2: memref<?xf32>):
%6 = bufferization.clone %1 : (memref<?xf32>) -> (memref<?xf32>) %6 = bufferization.clone %1 : (memref<?xf32>) -> (memref<?xf32>)
memref.dealloc %1 : memref<?xf32> memref.dealloc %1 : memref<?xf32>
br ^bb6(%6 : memref<?xf32>) cf.br ^bb6(%6 : memref<?xf32>)
^bb6(%3: memref<?xf32>): ^bb6(%3: memref<?xf32>):
br ^bb7(%3 : memref<?xf32>) cf.br ^bb7(%3 : memref<?xf32>)
^bb7(%4: memref<?xf32>): ^bb7(%4: memref<?xf32>):
test.copy(%4, %arg2) : (memref<?xf32>, memref<?xf32>) -> () test.copy(%4, %arg2) : (memref<?xf32>, memref<?xf32>) -> ()
memref.dealloc %3 : memref<?xf32> // free %3, since %4 is a non-crit. alias of %3 memref.dealloc %3 : memref<?xf32> // free %3, since %4 is a non-crit. alias of %3

View File

@ -295,7 +295,7 @@ A few examples are shown below:
```mlir ```mlir
// Expect an error on the same line. // Expect an error on the same line.
func @bad_branch() { func @bad_branch() {
br ^missing // expected-error {{reference to an undefined block}} cf.br ^missing // expected-error {{reference to an undefined block}}
} }
// Expect an error on an adjacent line. // Expect an error on an adjacent line.

View File

@ -114,8 +114,8 @@ struct MyTarget : public ConversionTarget {
/// All operations within the GPU dialect are illegal. /// All operations within the GPU dialect are illegal.
addIllegalDialect<GPUDialect>(); addIllegalDialect<GPUDialect>();
/// Mark `std.br` and `std.cond_br` as illegal. /// Mark `cf.br` and `cf.cond_br` as illegal.
addIllegalOp<BranchOp, CondBranchOp>(); addIllegalOp<cf::BranchOp, cf::CondBranchOp>();
} }
/// Implement the default legalization handler to handle operations marked as /// Implement the default legalization handler to handle operations marked as

View File

@ -23,10 +23,11 @@ argument `-declare-variables-at-top`.
Besides operations part of the EmitC dialect, the Cpp targets supports Besides operations part of the EmitC dialect, the Cpp targets supports
translating the following operations: translating the following operations:
* 'cf' Dialect
* `cf.br`
* `cf.cond_br`
* 'std' Dialect * 'std' Dialect
* `std.br`
* `std.call` * `std.call`
* `std.cond_br`
* `std.constant` * `std.constant`
* `std.return` * `std.return`
* 'scf' Dialect * 'scf' Dialect

View File

@ -391,21 +391,21 @@ arguments:
```mlir ```mlir
func @simple(i64, i1) -> i64 { func @simple(i64, i1) -> i64 {
^bb0(%a: i64, %cond: i1): // Code dominated by ^bb0 may refer to %a ^bb0(%a: i64, %cond: i1): // Code dominated by ^bb0 may refer to %a
cond_br %cond, ^bb1, ^bb2 cf.cond_br %cond, ^bb1, ^bb2
^bb1: ^bb1:
br ^bb3(%a: i64) // Branch passes %a as the argument cf.br ^bb3(%a: i64) // Branch passes %a as the argument
^bb2: ^bb2:
%b = arith.addi %a, %a : i64 %b = arith.addi %a, %a : i64
br ^bb3(%b: i64) // Branch passes %b as the argument cf.br ^bb3(%b: i64) // Branch passes %b as the argument
// ^bb3 receives an argument, named %c, from predecessors // ^bb3 receives an argument, named %c, from predecessors
// and passes it on to bb4 along with %a. %a is referenced // and passes it on to bb4 along with %a. %a is referenced
// directly from its defining operation and is not passed through // directly from its defining operation and is not passed through
// an argument of ^bb3. // an argument of ^bb3.
^bb3(%c: i64): ^bb3(%c: i64):
br ^bb4(%c, %a : i64, i64) cf.br ^bb4(%c, %a : i64, i64)
^bb4(%d : i64, %e : i64): ^bb4(%d : i64, %e : i64):
%0 = arith.addi %d, %e : i64 %0 = arith.addi %d, %e : i64
@ -525,12 +525,12 @@ Example:
```mlir ```mlir
func @accelerator_compute(i64, i1) -> i64 { // An SSACFG region func @accelerator_compute(i64, i1) -> i64 { // An SSACFG region
^bb0(%a: i64, %cond: i1): // Code dominated by ^bb0 may refer to %a ^bb0(%a: i64, %cond: i1): // Code dominated by ^bb0 may refer to %a
cond_br %cond, ^bb1, ^bb2 cf.cond_br %cond, ^bb1, ^bb2
^bb1: ^bb1:
// This def for %value does not dominate ^bb2 // This def for %value does not dominate ^bb2
%value = "op.convert"(%a) : (i64) -> i64 %value = "op.convert"(%a) : (i64) -> i64
br ^bb3(%a: i64) // Branch passes %a as the argument cf.br ^bb3(%a: i64) // Branch passes %a as the argument
^bb2: ^bb2:
accelerator.launch() { // An SSACFG region accelerator.launch() { // An SSACFG region

View File

@ -356,24 +356,24 @@ Example output is shown below:
``` ```
//===-------------------------------------------===// //===-------------------------------------------===//
Processing operation : 'std.cond_br'(0x60f000001120) { Processing operation : 'cf.cond_br'(0x60f000001120) {
"std.cond_br"(%arg0)[^bb2, ^bb2] {operand_segment_sizes = dense<[1, 0, 0]> : vector<3xi32>} : (i1) -> () "cf.cond_br"(%arg0)[^bb2, ^bb2] {operand_segment_sizes = dense<[1, 0, 0]> : vector<3xi32>} : (i1) -> ()
* Pattern SimplifyConstCondBranchPred : 'std.cond_br -> ()' { * Pattern SimplifyConstCondBranchPred : 'cf.cond_br -> ()' {
} -> failure : pattern failed to match } -> failure : pattern failed to match
* Pattern SimplifyCondBranchIdenticalSuccessors : 'std.cond_br -> ()' { * Pattern SimplifyCondBranchIdenticalSuccessors : 'cf.cond_br -> ()' {
** Insert : 'std.br'(0x60b000003690) ** Insert : 'cf.br'(0x60b000003690)
** Replace : 'std.cond_br'(0x60f000001120) ** Replace : 'cf.cond_br'(0x60f000001120)
} -> success : pattern applied successfully } -> success : pattern applied successfully
} -> success : pattern matched } -> success : pattern matched
//===-------------------------------------------===// //===-------------------------------------------===//
``` ```
This output is describing the processing of a `std.cond_br` operation. We first This output is describing the processing of a `cf.cond_br` operation. We first
try to apply the `SimplifyConstCondBranchPred`, which fails. From there, another try to apply the `SimplifyConstCondBranchPred`, which fails. From there, another
pattern (`SimplifyCondBranchIdenticalSuccessors`) is applied that matches the pattern (`SimplifyCondBranchIdenticalSuccessors`) is applied that matches the
`std.cond_br` and replaces it with a `std.br`. `cf.cond_br` and replaces it with a `cf.br`.
## Debugging ## Debugging

View File

@ -560,24 +560,24 @@ func @search(%A: memref<?x?xi32>, %S: <?xi32>, %key : i32) {
func @search_body(%A: memref<?x?xi32>, %S: memref<?xi32>, %key: i32, %i : i32) { func @search_body(%A: memref<?x?xi32>, %S: memref<?xi32>, %key: i32, %i : i32) {
%nj = memref.dim %A, 1 : memref<?x?xi32> %nj = memref.dim %A, 1 : memref<?x?xi32>
br ^bb1(0) cf.br ^bb1(0)
^bb1(%j: i32) ^bb1(%j: i32)
%p1 = arith.cmpi "lt", %j, %nj : i32 %p1 = arith.cmpi "lt", %j, %nj : i32
cond_br %p1, ^bb2, ^bb5 cf.cond_br %p1, ^bb2, ^bb5
^bb2: ^bb2:
%v = affine.load %A[%i, %j] : memref<?x?xi32> %v = affine.load %A[%i, %j] : memref<?x?xi32>
%p2 = arith.cmpi "eq", %v, %key : i32 %p2 = arith.cmpi "eq", %v, %key : i32
cond_br %p2, ^bb3(%j), ^bb4 cf.cond_br %p2, ^bb3(%j), ^bb4
^bb3(%j: i32) ^bb3(%j: i32)
affine.store %j, %S[%i] : memref<?xi32> affine.store %j, %S[%i] : memref<?xi32>
br ^bb5 cf.br ^bb5
^bb4: ^bb4:
%jinc = arith.addi %j, 1 : i32 %jinc = arith.addi %j, 1 : i32
br ^bb1(%jinc) cf.br ^bb1(%jinc)
^bb5: ^bb5:
return return

View File

@ -94,10 +94,11 @@ multiple stages by relying on
```c++ ```c++
mlir::RewritePatternSet patterns(&getContext()); mlir::RewritePatternSet patterns(&getContext());
mlir::populateAffineToStdConversionPatterns(patterns, &getContext()); mlir::populateAffineToStdConversionPatterns(patterns, &getContext());
mlir::populateLoopToStdConversionPatterns(patterns, &getContext()); mlir::populateSCFToControlFlowConversionPatterns(patterns);
mlir::arith::populateArithmeticToLLVMConversionPatterns(typeConverter, mlir::arith::populateArithmeticToLLVMConversionPatterns(typeConverter,
patterns); patterns);
mlir::populateStdToLLVMConversionPatterns(typeConverter, patterns); mlir::populateStdToLLVMConversionPatterns(typeConverter, patterns);
mlir::cf::populateControlFlowToLLVMConversionPatterns(patterns, &getContext());
// The only remaining operation, to lower from the `toy` dialect, is the // The only remaining operation, to lower from the `toy` dialect, is the
// PrintOp. // PrintOp.
@ -207,7 +208,7 @@ define void @main() {
%109 = memref.load double, double* %108 %109 = memref.load double, double* %108
%110 = call i32 (i8*, ...) @printf(i8* getelementptr inbounds ([4 x i8], [4 x i8]* @frmt_spec, i64 0, i64 0), double %109) %110 = call i32 (i8*, ...) @printf(i8* getelementptr inbounds ([4 x i8], [4 x i8]* @frmt_spec, i64 0, i64 0), double %109)
%111 = add i64 %100, 1 %111 = add i64 %100, 1
br label %99 br label %99
... ...

View File

@ -361,7 +361,7 @@
</tspan></tspan><tspan </tspan></tspan><tspan
x="73.476562" x="73.476562"
y="88.293896"><tspan y="88.293896"><tspan
style="font-size:5.64444px">br bb3(%0)</tspan></tspan></text> style="font-size:5.64444px">cf.br bb3(%0)</tspan></tspan></text>
<text <text
xml:space="preserve" xml:space="preserve"
id="text1894" id="text1894"

Before

Width:  |  Height:  |  Size: 15 KiB

After

Width:  |  Height:  |  Size: 15 KiB

View File

@ -362,7 +362,7 @@
</tspan></tspan><tspan </tspan></tspan><tspan
x="73.476562" x="73.476562"
y="88.293896"><tspan y="88.293896"><tspan
style="font-size:5.64444px">br bb3(%0)</tspan></tspan></text> style="font-size:5.64444px">cf.br bb3(%0)</tspan></tspan></text>
<text <text
xml:space="preserve" xml:space="preserve"
id="text1894" id="text1894"

Before

Width:  |  Height:  |  Size: 15 KiB

After

Width:  |  Height:  |  Size: 15 KiB

View File

@ -26,10 +26,11 @@
#include "mlir/Conversion/AffineToStandard/AffineToStandard.h" #include "mlir/Conversion/AffineToStandard/AffineToStandard.h"
#include "mlir/Conversion/ArithmeticToLLVM/ArithmeticToLLVM.h" #include "mlir/Conversion/ArithmeticToLLVM/ArithmeticToLLVM.h"
#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h" #include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h" #include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h" #include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
#include "mlir/Conversion/SCFToStandard/SCFToStandard.h" #include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h"
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h" #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Affine/IR/AffineOps.h"
@ -200,10 +201,11 @@ void ToyToLLVMLoweringPass::runOnOperation() {
// set of legal ones. // set of legal ones.
RewritePatternSet patterns(&getContext()); RewritePatternSet patterns(&getContext());
populateAffineToStdConversionPatterns(patterns); populateAffineToStdConversionPatterns(patterns);
populateLoopToStdConversionPatterns(patterns); populateSCFToControlFlowConversionPatterns(patterns);
mlir::arith::populateArithmeticToLLVMConversionPatterns(typeConverter, mlir::arith::populateArithmeticToLLVMConversionPatterns(typeConverter,
patterns); patterns);
populateMemRefToLLVMConversionPatterns(typeConverter, patterns); populateMemRefToLLVMConversionPatterns(typeConverter, patterns);
cf::populateControlFlowToLLVMConversionPatterns(typeConverter, patterns);
populateStdToLLVMConversionPatterns(typeConverter, patterns); populateStdToLLVMConversionPatterns(typeConverter, patterns);
// The only remaining operation to lower from the `toy` dialect, is the // The only remaining operation to lower from the `toy` dialect, is the

View File

@ -26,10 +26,11 @@
#include "mlir/Conversion/AffineToStandard/AffineToStandard.h" #include "mlir/Conversion/AffineToStandard/AffineToStandard.h"
#include "mlir/Conversion/ArithmeticToLLVM/ArithmeticToLLVM.h" #include "mlir/Conversion/ArithmeticToLLVM/ArithmeticToLLVM.h"
#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h" #include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h" #include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h" #include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
#include "mlir/Conversion/SCFToStandard/SCFToStandard.h" #include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h"
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h" #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Affine/IR/AffineOps.h"
@ -200,10 +201,11 @@ void ToyToLLVMLoweringPass::runOnOperation() {
// set of legal ones. // set of legal ones.
RewritePatternSet patterns(&getContext()); RewritePatternSet patterns(&getContext());
populateAffineToStdConversionPatterns(patterns); populateAffineToStdConversionPatterns(patterns);
populateLoopToStdConversionPatterns(patterns); populateSCFToControlFlowConversionPatterns(patterns);
mlir::arith::populateArithmeticToLLVMConversionPatterns(typeConverter, mlir::arith::populateArithmeticToLLVMConversionPatterns(typeConverter,
patterns); patterns);
populateMemRefToLLVMConversionPatterns(typeConverter, patterns); populateMemRefToLLVMConversionPatterns(typeConverter, patterns);
cf::populateControlFlowToLLVMConversionPatterns(typeConverter, patterns);
populateStdToLLVMConversionPatterns(typeConverter, patterns); populateStdToLLVMConversionPatterns(typeConverter, patterns);
// The only remaining operation to lower from the `toy` dialect, is the // The only remaining operation to lower from the `toy` dialect, is the

View File

@ -0,0 +1,35 @@
//===- ControlFlowToLLVM.h - ControlFlow to LLVM -----------*- C++ ------*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Define conversions from the ControlFlow dialect to the LLVM IR dialect.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_CONVERSION_CONTROLFLOWTOLLVM_CONTROLFLOWTOLLVM_H
#define MLIR_CONVERSION_CONTROLFLOWTOLLVM_CONTROLFLOWTOLLVM_H
#include <memory>
namespace mlir {
// Forward declarations: the conversion interface only needs these by
// reference, so we avoid pulling in the heavyweight headers here.
class LLVMTypeConverter;
class RewritePatternSet;
class Pass;
namespace cf {
/// Collect the patterns to convert from the ControlFlow dialect to LLVM. The
/// conversion patterns capture the LLVMTypeConverter by reference meaning the
/// references have to remain alive during the entire pattern lifetime.
void populateControlFlowToLLVMConversionPatterns(LLVMTypeConverter &converter,
RewritePatternSet &patterns);
/// Creates a pass to convert the ControlFlow dialect into the LLVMIR dialect.
/// Registered as `-convert-cf-to-llvm`; runs on the top-level ModuleOp and
/// fails if non-convertible operations feed the converted LLVM operations.
std::unique_ptr<Pass> createConvertControlFlowToLLVMPass();
} // namespace cf
} // namespace mlir
#endif // MLIR_CONVERSION_CONTROLFLOWTOLLVM_CONTROLFLOWTOLLVM_H

View File

@ -0,0 +1,28 @@
//===- ControlFlowToSPIRV.h - CF to SPIR-V Patterns --------*- C++ ------*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Provides patterns to convert ControlFlow dialect to SPIR-V dialect.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_CONVERSION_CONTROLFLOWTOSPIRV_CONTROLFLOWTOSPIRV_H
#define MLIR_CONVERSION_CONTROLFLOWTOSPIRV_CONTROLFLOWTOSPIRV_H
namespace mlir {
// Forward declarations; this header intentionally has no includes.
class RewritePatternSet;
class SPIRVTypeConverter;
namespace cf {
/// Appends to a pattern list additional patterns for translating ControlFlow
/// ops to SPIR-V ops.
void populateControlFlowToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
RewritePatternSet &patterns);
} // namespace cf
} // namespace mlir
#endif // MLIR_CONVERSION_CONTROLFLOWTOSPIRV_CONTROLFLOWTOSPIRV_H

View File

@ -17,6 +17,8 @@
#include "mlir/Conversion/BufferizationToMemRef/BufferizationToMemRef.h" #include "mlir/Conversion/BufferizationToMemRef/BufferizationToMemRef.h"
#include "mlir/Conversion/ComplexToLLVM/ComplexToLLVM.h" #include "mlir/Conversion/ComplexToLLVM/ComplexToLLVM.h"
#include "mlir/Conversion/ComplexToStandard/ComplexToStandard.h" #include "mlir/Conversion/ComplexToStandard/ComplexToStandard.h"
#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
#include "mlir/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRV.h"
#include "mlir/Conversion/GPUCommon/GPUCommonPass.h" #include "mlir/Conversion/GPUCommon/GPUCommonPass.h"
#include "mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h" #include "mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h"
#include "mlir/Conversion/GPUToROCDL/GPUToROCDLPass.h" #include "mlir/Conversion/GPUToROCDL/GPUToROCDLPass.h"
@ -35,10 +37,10 @@
#include "mlir/Conversion/OpenMPToLLVM/ConvertOpenMPToLLVM.h" #include "mlir/Conversion/OpenMPToLLVM/ConvertOpenMPToLLVM.h"
#include "mlir/Conversion/PDLToPDLInterp/PDLToPDLInterp.h" #include "mlir/Conversion/PDLToPDLInterp/PDLToPDLInterp.h"
#include "mlir/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.h" #include "mlir/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.h"
#include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h"
#include "mlir/Conversion/SCFToGPU/SCFToGPUPass.h" #include "mlir/Conversion/SCFToGPU/SCFToGPUPass.h"
#include "mlir/Conversion/SCFToOpenMP/SCFToOpenMP.h" #include "mlir/Conversion/SCFToOpenMP/SCFToOpenMP.h"
#include "mlir/Conversion/SCFToSPIRV/SCFToSPIRVPass.h" #include "mlir/Conversion/SCFToSPIRV/SCFToSPIRVPass.h"
#include "mlir/Conversion/SCFToStandard/SCFToStandard.h"
#include "mlir/Conversion/SPIRVToLLVM/SPIRVToLLVMPass.h" #include "mlir/Conversion/SPIRVToLLVM/SPIRVToLLVMPass.h"
#include "mlir/Conversion/ShapeToStandard/ShapeToStandard.h" #include "mlir/Conversion/ShapeToStandard/ShapeToStandard.h"
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h" #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"

View File

@ -181,6 +181,28 @@ def ConvertComplexToStandard : Pass<"convert-complex-to-standard", "FuncOp"> {
let dependentDialects = ["math::MathDialect"]; let dependentDialects = ["math::MathDialect"];
} }
//===----------------------------------------------------------------------===//
// ControlFlowToLLVM
//===----------------------------------------------------------------------===//
def ConvertControlFlowToLLVM : Pass<"convert-cf-to-llvm", "ModuleOp"> {
let summary = "Convert ControlFlow operations to the LLVM dialect";
let description = [{
Convert ControlFlow operations into LLVM IR dialect operations.
If other operations are present and their results are required by the LLVM
IR dialect operations, the pass will fail. Any LLVM IR operations or types
already present in the IR will be kept as is.
}];
let constructor = "mlir::cf::createConvertControlFlowToLLVMPass()";
let dependentDialects = ["LLVM::LLVMDialect"];
let options = [
Option<"indexBitwidth", "index-bitwidth", "unsigned",
/*default=kDeriveIndexBitwidthFromDataLayout*/"0",
"Bitwidth of the index type, 0 to use size of machine word">,
];
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// GPUCommon // GPUCommon
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -460,6 +482,17 @@ def ReconcileUnrealizedCasts : Pass<"reconcile-unrealized-casts"> {
let constructor = "mlir::createReconcileUnrealizedCastsPass()"; let constructor = "mlir::createReconcileUnrealizedCastsPass()";
} }
//===----------------------------------------------------------------------===//
// SCFToControlFlow
//===----------------------------------------------------------------------===//
def SCFToControlFlow : Pass<"convert-scf-to-cf"> {
let summary = "Convert SCF dialect to ControlFlow dialect, replacing structured"
" control flow with a CFG";
let constructor = "mlir::createConvertSCFToCFPass()";
let dependentDialects = ["cf::ControlFlowDialect"];
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// SCFToOpenMP // SCFToOpenMP
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -488,17 +521,6 @@ def SCFToSPIRV : Pass<"convert-scf-to-spirv", "ModuleOp"> {
let dependentDialects = ["spirv::SPIRVDialect"]; let dependentDialects = ["spirv::SPIRVDialect"];
} }
//===----------------------------------------------------------------------===//
// SCFToStandard
//===----------------------------------------------------------------------===//
def SCFToStandard : Pass<"convert-scf-to-std"> {
let summary = "Convert SCF dialect to Standard dialect, replacing structured"
" control flow with a CFG";
let constructor = "mlir::createLowerToCFGPass()";
let dependentDialects = ["StandardOpsDialect"];
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// SCFToGPU // SCFToGPU
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -547,7 +569,7 @@ def ConvertShapeConstraints: Pass<"convert-shape-constraints", "FuncOp"> {
computation lowering. computation lowering.
}]; }];
let constructor = "mlir::createConvertShapeConstraintsPass()"; let constructor = "mlir::createConvertShapeConstraintsPass()";
let dependentDialects = ["StandardOpsDialect", "scf::SCFDialect"]; let dependentDialects = ["cf::ControlFlowDialect", "scf::SCFDialect"];
} }
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//

View File

@ -0,0 +1,28 @@
//===- SCFToControlFlow.h - Pass entrypoint ---------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_CONVERSION_SCFTOCONTROLFLOW_SCFTOCONTROLFLOW_H_
#define MLIR_CONVERSION_SCFTOCONTROLFLOW_SCFTOCONTROLFLOW_H_
#include <memory>
namespace mlir {
class Pass;
class RewritePatternSet;
/// Collect a set of patterns to convert SCF operations to CFG branch-based
/// operations within the ControlFlow dialect.
void populateSCFToControlFlowConversionPatterns(RewritePatternSet &patterns);
/// Creates a pass to convert SCF operations to CFG branch-based operation in
/// the ControlFlow dialect. Registered as `-convert-scf-to-cf`.
std::unique_ptr<Pass> createConvertSCFToCFPass();
} // namespace mlir
#endif // MLIR_CONVERSION_SCFTOCONTROLFLOW_SCFTOCONTROLFLOW_H_

View File

@ -1,31 +0,0 @@
//===- ConvertSCFToStandard.h - Pass entrypoint -----------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_CONVERSION_SCFTOSTANDARD_SCFTOSTANDARD_H_
#define MLIR_CONVERSION_SCFTOSTANDARD_SCFTOSTANDARD_H_
#include <memory>
#include <vector>
namespace mlir {
struct LogicalResult;
class Pass;
class RewritePatternSet;
/// Collect a set of patterns to lower from scf.for, scf.if, and
/// loop.terminator to CFG operations within the Standard dialect, in particular
/// convert structured control flow into CFG branch-based control flow.
void populateLoopToStdConversionPatterns(RewritePatternSet &patterns);
/// Creates a pass to convert scf.for, scf.if and loop.terminator ops to CFG.
std::unique_ptr<Pass> createLowerToCFGPass();
} // namespace mlir
#endif // MLIR_CONVERSION_SCFTOSTANDARD_SCFTOSTANDARD_H_

View File

@ -26,9 +26,9 @@ def BufferDeallocation : Pass<"buffer-deallocation", "FuncOp"> {
#map0 = affine_map<(d0) -> (d0)> #map0 = affine_map<(d0) -> (d0)>
module { module {
func @condBranch(%arg0: i1, %arg1: memref<2xf32>, %arg2: memref<2xf32>) { func @condBranch(%arg0: i1, %arg1: memref<2xf32>, %arg2: memref<2xf32>) {
cond_br %arg0, ^bb1, ^bb2 cf.cond_br %arg0, ^bb1, ^bb2
^bb1: ^bb1:
br ^bb3(%arg1 : memref<2xf32>) cf.br ^bb3(%arg1 : memref<2xf32>)
^bb2: ^bb2:
%0 = memref.alloc() : memref<2xf32> %0 = memref.alloc() : memref<2xf32>
linalg.generic { linalg.generic {
@ -40,7 +40,7 @@ def BufferDeallocation : Pass<"buffer-deallocation", "FuncOp"> {
%tmp1 = exp %gen1_arg0 : f32 %tmp1 = exp %gen1_arg0 : f32
linalg.yield %tmp1 : f32 linalg.yield %tmp1 : f32
}: memref<2xf32>, memref<2xf32> }: memref<2xf32>, memref<2xf32>
br ^bb3(%0 : memref<2xf32>) cf.br ^bb3(%0 : memref<2xf32>)
^bb3(%1: memref<2xf32>): ^bb3(%1: memref<2xf32>):
"memref.copy"(%1, %arg2) : (memref<2xf32>, memref<2xf32>) -> () "memref.copy"(%1, %arg2) : (memref<2xf32>, memref<2xf32>) -> ()
return return
@ -55,11 +55,11 @@ def BufferDeallocation : Pass<"buffer-deallocation", "FuncOp"> {
#map0 = affine_map<(d0) -> (d0)> #map0 = affine_map<(d0) -> (d0)>
module { module {
func @condBranch(%arg0: i1, %arg1: memref<2xf32>, %arg2: memref<2xf32>) { func @condBranch(%arg0: i1, %arg1: memref<2xf32>, %arg2: memref<2xf32>) {
cond_br %arg0, ^bb1, ^bb2 cf.cond_br %arg0, ^bb1, ^bb2
^bb1: // pred: ^bb0 ^bb1: // pred: ^bb0
%0 = memref.alloc() : memref<2xf32> %0 = memref.alloc() : memref<2xf32>
memref.copy(%arg1, %0) : memref<2xf32>, memref<2xf32> memref.copy(%arg1, %0) : memref<2xf32>, memref<2xf32>
br ^bb3(%0 : memref<2xf32>) cf.br ^bb3(%0 : memref<2xf32>)
^bb2: // pred: ^bb0 ^bb2: // pred: ^bb0
%1 = memref.alloc() : memref<2xf32> %1 = memref.alloc() : memref<2xf32>
linalg.generic { linalg.generic {
@ -74,7 +74,7 @@ def BufferDeallocation : Pass<"buffer-deallocation", "FuncOp"> {
%2 = memref.alloc() : memref<2xf32> %2 = memref.alloc() : memref<2xf32>
memref.copy(%1, %2) : memref<2xf32>, memref<2xf32> memref.copy(%1, %2) : memref<2xf32>, memref<2xf32>
dealloc %1 : memref<2xf32> dealloc %1 : memref<2xf32>
br ^bb3(%2 : memref<2xf32>) cf.br ^bb3(%2 : memref<2xf32>)
^bb3(%3: memref<2xf32>): // 2 preds: ^bb1, ^bb2 ^bb3(%3: memref<2xf32>): // 2 preds: ^bb1, ^bb2
memref.copy(%3, %arg2) : memref<2xf32>, memref<2xf32> memref.copy(%3, %arg2) : memref<2xf32>, memref<2xf32>
dealloc %3 : memref<2xf32> dealloc %3 : memref<2xf32>

View File

@ -6,6 +6,7 @@ add_subdirectory(ArmSVE)
add_subdirectory(AMX) add_subdirectory(AMX)
add_subdirectory(Bufferization) add_subdirectory(Bufferization)
add_subdirectory(Complex) add_subdirectory(Complex)
add_subdirectory(ControlFlow)
add_subdirectory(DLTI) add_subdirectory(DLTI)
add_subdirectory(EmitC) add_subdirectory(EmitC)
add_subdirectory(GPU) add_subdirectory(GPU)

View File

@ -0,0 +1 @@
add_subdirectory(IR)

View File

@ -0,0 +1,2 @@
add_mlir_dialect(ControlFlowOps cf ControlFlowOps)
add_mlir_doc(ControlFlowOps ControlFlowDialect Dialects/ -gen-dialect-doc)

View File

@ -0,0 +1,21 @@
//===- ControlFlow.h - ControlFlow Dialect ----------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file defines the ControlFlow dialect.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_DIALECT_CONTROLFLOW_IR_CONTROLFLOW_H
#define MLIR_DIALECT_CONTROLFLOW_IR_CONTROLFLOW_H
// Arithmetic is a declared dependent dialect of ControlFlow, so pull in its
// definitions alongside the core dialect machinery.
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/IR/Dialect.h"
// Generated dialect declaration (ControlFlowDialect class).
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOpsDialect.h.inc"
#endif // MLIR_DIALECT_CONTROLFLOW_IR_CONTROLFLOW_H

View File

@ -0,0 +1,30 @@
//===- ControlFlowOps.h - ControlFlow Operations ----------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file defines the operations of the ControlFlow dialect.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_DIALECT_CONTROLFLOW_IR_CONTROLFLOWOPS_H
#define MLIR_DIALECT_CONTROLFLOW_IR_CONTROLFLOWOPS_H
#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
namespace mlir {
// Forward-declared for the canonicalization hooks the generated op classes
// declare (the ops set hasCanonicalizeMethod/hasCanonicalizer in ODS).
class PatternRewriter;
} // namespace mlir
// Pull in the generated declarations of the ControlFlow op classes.
#define GET_OP_CLASSES
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h.inc"
#endif // MLIR_DIALECT_CONTROLFLOW_IR_CONTROLFLOWOPS_H

View File

@ -0,0 +1,313 @@
//===- ControlFlowOps.td - ControlFlow operations ----------*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file contains definitions for the operations within the ControlFlow
// dialect.
//
//===----------------------------------------------------------------------===//
// Guard renamed from the copy-pasted STANDARD_OPS: the Standard ops .td still
// exists after this split, and reusing its guard would make one file silently
// drop out when both are transitively included.
#ifndef CONTROLFLOW_OPS
#define CONTROLFLOW_OPS
include "mlir/IR/OpAsmInterface.td"
include "mlir/Interfaces/ControlFlowInterfaces.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
// Definition of the `cf` dialect itself.
def ControlFlow_Dialect : Dialect {
let name = "cf";
let cppNamespace = "::mlir::cf";
// NOTE(review): arith dependency presumably needed so folds/canonicalizations
// can materialize arith constants — confirm against the op implementations.
let dependentDialects = ["arith::ArithmeticDialect"];
let emitAccessorPrefix = kEmitAccessorPrefix_Prefixed;
let description = [{
This dialect contains low-level, i.e. non-region based, control flow
constructs. These constructs generally represent control flow directly
on SSA blocks of a control flow graph.
}];
}
// Base class for all operations in the ControlFlow dialect.
class CF_Op<string mnemonic, list<Trait> traits = []> :
Op<ControlFlow_Dialect, mnemonic, traits>;
//===----------------------------------------------------------------------===//
// AssertOp
//===----------------------------------------------------------------------===//
def AssertOp : CF_Op<"assert"> {
let summary = "Assert operation with message attribute";
let description = [{
Assert operation with single boolean operand and an error message attribute.
If the argument is `true` this operation has no effect. Otherwise, the
program execution will abort. The provided error message may be used by a
runtime to propagate the error to the user.
Example:
```mlir
cf.assert %b, "Expected ... to be true"
```
}];
let arguments = (ins I1:$arg, StrAttr:$msg);
let assemblyFormat = "$arg `,` $msg attr-dict";
// Canonicalization logic lives in C++ (AssertOp::canonicalize).
let hasCanonicalizeMethod = 1;
}
//===----------------------------------------------------------------------===//
// BranchOp
//===----------------------------------------------------------------------===//
def BranchOp : CF_Op<"br", [
DeclareOpInterfaceMethods<BranchOpInterface, ["getSuccessorForOperands"]>,
NoSideEffect, Terminator
]> {
let summary = "branch operation";
let description = [{
The `cf.br` operation represents a direct branch operation to a given
block. The operands of this operation are forwarded to the successor block,
and the number and type of the operands must match the arguments of the
target block.
Example:
```mlir
^bb2:
%2 = call @someFn()
cf.br ^bb3(%2 : tensor<*xf32>)
^bb3(%3: tensor<*xf32>):
```
}];
let arguments = (ins Variadic<AnyType>:$destOperands);
let successors = (successor AnySuccessor:$dest);
// Convenience builder: destination block plus (optional) forwarded operands.
let builders = [
OpBuilder<(ins "Block *":$dest,
CArg<"ValueRange", "{}">:$destOperands), [{
$_state.addSuccessors(dest);
$_state.addOperands(destOperands);
}]>];
let extraClassDeclaration = [{
/// Rewrite this branch to target 'block' instead of the current successor.
void setDest(Block *block);
/// Erase the operand at 'index' from the operand list.
void eraseOperand(unsigned index);
}];
// Canonicalization logic lives in C++ (BranchOp::canonicalize).
let hasCanonicalizeMethod = 1;
let assemblyFormat = [{
$dest (`(` $destOperands^ `:` type($destOperands) `)`)? attr-dict
}];
}
//===----------------------------------------------------------------------===//
// CondBranchOp
//===----------------------------------------------------------------------===//
def CondBranchOp : CF_Op<"cond_br",
[AttrSizedOperandSegments,
DeclareOpInterfaceMethods<BranchOpInterface, ["getSuccessorForOperands"]>,
NoSideEffect, Terminator]> {
let summary = "conditional branch operation";
let description = [{
The `cf.cond_br` terminator operation represents a conditional branch on a
boolean (1-bit integer) value. If the bit is set, then the first destination
is jumped to; if it is false, the second destination is chosen. The count
and types of operands must align with the arguments in the corresponding
target blocks.
The MLIR conditional branch operation is not allowed to target the entry
block for a region. The two destinations of the conditional branch operation
are allowed to be the same.
The following example illustrates a function with a conditional branch
operation that targets the same block.
Example:
```mlir
func @select(%a: i32, %b: i32, %flag: i1) -> i32 {
// Both targets are the same, operands differ
cf.cond_br %flag, ^bb1(%a : i32), ^bb1(%b : i32)
^bb1(%x : i32) :
return %x : i32
}
```
}];
let arguments = (ins I1:$condition,
Variadic<AnyType>:$trueDestOperands,
Variadic<AnyType>:$falseDestOperands);
let successors = (successor AnySuccessor:$trueDest, AnySuccessor:$falseDest);
// Convenience builders; the second one takes no true-destination operands.
let builders = [
OpBuilder<(ins "Value":$condition, "Block *":$trueDest,
"ValueRange":$trueOperands, "Block *":$falseDest,
"ValueRange":$falseOperands), [{
build($_builder, $_state, condition, trueOperands, falseOperands, trueDest,
falseDest);
}]>,
OpBuilder<(ins "Value":$condition, "Block *":$trueDest,
"Block *":$falseDest, CArg<"ValueRange", "{}">:$falseOperands), [{
build($_builder, $_state, condition, trueDest, ValueRange(), falseDest,
falseOperands);
}]>];
let extraClassDeclaration = [{
// These are the indices into the dests list.
enum { trueIndex = 0, falseIndex = 1 };
// Accessors for operands to the 'true' destination.
Value getTrueOperand(unsigned idx) {
assert(idx < getNumTrueOperands());
return getOperand(getTrueDestOperandIndex() + idx);
}
void setTrueOperand(unsigned idx, Value value) {
assert(idx < getNumTrueOperands());
setOperand(getTrueDestOperandIndex() + idx, value);
}
unsigned getNumTrueOperands() { return getTrueOperands().size(); }
/// Erase the operand at 'index' from the true operand list.
void eraseTrueOperand(unsigned index) {
getTrueDestOperandsMutable().erase(index);
}
// Accessors for operands to the 'false' destination.
Value getFalseOperand(unsigned idx) {
assert(idx < getNumFalseOperands());
return getOperand(getFalseDestOperandIndex() + idx);
}
void setFalseOperand(unsigned idx, Value value) {
assert(idx < getNumFalseOperands());
setOperand(getFalseDestOperandIndex() + idx, value);
}
operand_range getTrueOperands() { return getTrueDestOperands(); }
operand_range getFalseOperands() { return getFalseDestOperands(); }
unsigned getNumFalseOperands() { return getFalseOperands().size(); }
/// Erase the operand at 'index' from the false operand list.
void eraseFalseOperand(unsigned index) {
getFalseDestOperandsMutable().erase(index);
}
private:
/// Get the index of the first true destination operand.
unsigned getTrueDestOperandIndex() { return 1; }
/// Get the index of the first false destination operand.
unsigned getFalseDestOperandIndex() {
return getTrueDestOperandIndex() + getNumTrueOperands();
}
}];
let hasCanonicalizer = 1;
let assemblyFormat = [{
$condition `,`
$trueDest (`(` $trueDestOperands^ `:` type($trueDestOperands) `)`)? `,`
$falseDest (`(` $falseDestOperands^ `:` type($falseDestOperands) `)`)?
attr-dict
}];
}
//===----------------------------------------------------------------------===//
// SwitchOp
//===----------------------------------------------------------------------===//
// `cf.switch`: terminator dispatching on a signless-integer flag to one of
// several case destinations, falling back to a default destination.
def SwitchOp : CF_Op<"switch",
    [AttrSizedOperandSegments,
     DeclareOpInterfaceMethods<BranchOpInterface, ["getSuccessorForOperands"]>,
     NoSideEffect, Terminator]> {
  let summary = "switch operation";
  let description = [{
    The `switch` terminator operation represents a switch on a signless integer
    value. If the flag matches one of the specified cases, then the
    corresponding destination is jumped to. If the flag does not match any of
    the cases, the default destination is jumped to. The count and types of
    operands must align with the arguments in the corresponding target blocks.
    Example:
    ```mlir
    switch %flag : i32, [
      default: ^bb1(%a : i32),
      42: ^bb1(%b : i32),
      43: ^bb3(%c : i32)
    ]
    ```
  }];
  // `case_operand_segments` records how the flat `caseOperands` list is
  // partitioned among the case successors; `case_values` is absent when the
  // op has no cases (default-only switch).
  let arguments = (ins
    AnyInteger:$flag,
    Variadic<AnyType>:$defaultOperands,
    VariadicOfVariadic<AnyType, "case_operand_segments">:$caseOperands,
    OptionalAttr<AnyIntElementsAttr>:$case_values,
    I32ElementsAttr:$case_operand_segments
  );
  let successors = (successor
    AnySuccessor:$defaultDestination,
    VariadicSuccessor<AnySuccessor>:$caseDestinations
  );
  // Convenience builders taking the case values as APInts, as 32-bit
  // integers, or as a preconstructed DenseIntElementsAttr.
  let builders = [
    OpBuilder<(ins "Value":$flag,
      "Block *":$defaultDestination,
      "ValueRange":$defaultOperands,
      CArg<"ArrayRef<APInt>", "{}">:$caseValues,
      CArg<"BlockRange", "{}">:$caseDestinations,
      CArg<"ArrayRef<ValueRange>", "{}">:$caseOperands)>,
    OpBuilder<(ins "Value":$flag,
      "Block *":$defaultDestination,
      "ValueRange":$defaultOperands,
      CArg<"ArrayRef<int32_t>", "{}">:$caseValues,
      CArg<"BlockRange", "{}">:$caseDestinations,
      CArg<"ArrayRef<ValueRange>", "{}">:$caseOperands)>,
    OpBuilder<(ins "Value":$flag,
      "Block *":$defaultDestination,
      "ValueRange":$defaultOperands,
      CArg<"DenseIntElementsAttr", "{}">:$caseValues,
      CArg<"BlockRange", "{}">:$caseDestinations,
      CArg<"ArrayRef<ValueRange>", "{}">:$caseOperands)>
  ];
  let assemblyFormat = [{
    $flag `:` type($flag) `,` `[` `\n`
      custom<SwitchOpCases>(ref(type($flag)),$defaultDestination,
                            $defaultOperands,
                            type($defaultOperands),
                            $case_values,
                            $caseDestinations,
                            $caseOperands,
                            type($caseOperands))
   `]`
    attr-dict
  }];
  let extraClassDeclaration = [{
    /// Return the operands for the case destination block at the given index.
    OperandRange getCaseOperands(unsigned index) {
      return getCaseOperands()[index];
    }
    /// Return a mutable range of operands for the case destination block at the
    /// given index.
    MutableOperandRange getCaseOperandsMutable(unsigned index) {
      return getCaseOperandsMutable()[index];
    }
  }];
  let hasCanonicalizer = 1;
  let hasVerifier = 1;
}
#endif // STANDARD_OPS

View File

@ -84,15 +84,15 @@ def ExecuteRegionOp : SCF_Op<"execute_region"> {
affine.for %i = 0 to 100 { affine.for %i = 0 to 100 {
"foo"() : () -> () "foo"() : () -> ()
%v = scf.execute_region -> i64 { %v = scf.execute_region -> i64 {
cond_br %cond, ^bb1, ^bb2 cf.cond_br %cond, ^bb1, ^bb2
^bb1: ^bb1:
%c1 = arith.constant 1 : i64 %c1 = arith.constant 1 : i64
br ^bb3(%c1 : i64) cf.br ^bb3(%c1 : i64)
^bb2: ^bb2:
%c2 = arith.constant 2 : i64 %c2 = arith.constant 2 : i64
br ^bb3(%c2 : i64) cf.br ^bb3(%c2 : i64)
^bb3(%x : i64): ^bb3(%x : i64):
scf.yield %x : i64 scf.yield %x : i64

View File

@ -14,7 +14,7 @@
#ifndef MLIR_DIALECT_STANDARDOPS_IR_OPS_H #ifndef MLIR_DIALECT_STANDARDOPS_IR_OPS_H
#define MLIR_DIALECT_STANDARDOPS_IR_OPS_H #define MLIR_DIALECT_STANDARDOPS_IR_OPS_H
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
#include "mlir/IR/Builders.h" #include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Dialect.h" #include "mlir/IR/Dialect.h"
@ -24,7 +24,6 @@
#include "mlir/Interfaces/ControlFlowInterfaces.h" #include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Interfaces/InferTypeOpInterface.h" #include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Interfaces/SideEffectInterfaces.h" #include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Interfaces/VectorInterfaces.h"
// Pull in all enum type definitions and utility function declarations. // Pull in all enum type definitions and utility function declarations.
#include "mlir/Dialect/StandardOps/IR/OpsEnums.h.inc" #include "mlir/Dialect/StandardOps/IR/OpsEnums.h.inc"

View File

@ -20,12 +20,11 @@ include "mlir/Interfaces/CastInterfaces.td"
include "mlir/Interfaces/ControlFlowInterfaces.td" include "mlir/Interfaces/ControlFlowInterfaces.td"
include "mlir/Interfaces/InferTypeOpInterface.td" include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/Interfaces/SideEffectInterfaces.td" include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/Interfaces/VectorInterfaces.td"
def StandardOps_Dialect : Dialect { def StandardOps_Dialect : Dialect {
let name = "std"; let name = "std";
let cppNamespace = "::mlir"; let cppNamespace = "::mlir";
let dependentDialects = ["arith::ArithmeticDialect"]; let dependentDialects = ["cf::ControlFlowDialect"];
let hasConstantMaterializer = 1; let hasConstantMaterializer = 1;
let emitAccessorPrefix = kEmitAccessorPrefix_Prefixed; let emitAccessorPrefix = kEmitAccessorPrefix_Prefixed;
} }
@ -42,78 +41,6 @@ class Std_Op<string mnemonic, list<Trait> traits = []> :
let parser = [{ return ::parse$cppClass(parser, result); }]; let parser = [{ return ::parse$cppClass(parser, result); }];
} }
//===----------------------------------------------------------------------===//
// AssertOp
//===----------------------------------------------------------------------===//
def AssertOp : Std_Op<"assert"> {
let summary = "Assert operation with message attribute";
let description = [{
Assert operation with single boolean operand and an error message attribute.
If the argument is `true` this operation has no effect. Otherwise, the
program execution will abort. The provided error message may be used by a
runtime to propagate the error to the user.
Example:
```mlir
assert %b, "Expected ... to be true"
```
}];
let arguments = (ins I1:$arg, StrAttr:$msg);
let assemblyFormat = "$arg `,` $msg attr-dict";
let hasCanonicalizeMethod = 1;
}
//===----------------------------------------------------------------------===//
// BranchOp
//===----------------------------------------------------------------------===//
def BranchOp : Std_Op<"br",
[DeclareOpInterfaceMethods<BranchOpInterface, ["getSuccessorForOperands"]>,
NoSideEffect, Terminator]> {
let summary = "branch operation";
let description = [{
The `br` operation represents a branch operation in a function.
The operation takes variable number of operands and produces no results.
The operand number and types for each successor must match the arguments of
the block successor.
Example:
```mlir
^bb2:
%2 = call @someFn()
br ^bb3(%2 : tensor<*xf32>)
^bb3(%3: tensor<*xf32>):
```
}];
let arguments = (ins Variadic<AnyType>:$destOperands);
let successors = (successor AnySuccessor:$dest);
let builders = [
OpBuilder<(ins "Block *":$dest,
CArg<"ValueRange", "{}">:$destOperands), [{
$_state.addSuccessors(dest);
$_state.addOperands(destOperands);
}]>];
let extraClassDeclaration = [{
void setDest(Block *block);
/// Erase the operand at 'index' from the operand list.
void eraseOperand(unsigned index);
}];
let hasCanonicalizeMethod = 1;
let assemblyFormat = [{
$dest (`(` $destOperands^ `:` type($destOperands) `)`)? attr-dict
}];
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// CallOp // CallOp
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -246,121 +173,6 @@ def CallIndirectOp : Std_Op<"call_indirect", [
"$callee `(` $callee_operands `)` attr-dict `:` type($callee)"; "$callee `(` $callee_operands `)` attr-dict `:` type($callee)";
} }
//===----------------------------------------------------------------------===//
// CondBranchOp
//===----------------------------------------------------------------------===//
def CondBranchOp : Std_Op<"cond_br",
[AttrSizedOperandSegments,
DeclareOpInterfaceMethods<BranchOpInterface, ["getSuccessorForOperands"]>,
NoSideEffect, Terminator]> {
let summary = "conditional branch operation";
let description = [{
The `cond_br` terminator operation represents a conditional branch on a
boolean (1-bit integer) value. If the bit is set, then the first destination
is jumped to; if it is false, the second destination is chosen. The count
and types of operands must align with the arguments in the corresponding
target blocks.
The MLIR conditional branch operation is not allowed to target the entry
block for a region. The two destinations of the conditional branch operation
are allowed to be the same.
The following example illustrates a function with a conditional branch
operation that targets the same block.
Example:
```mlir
func @select(%a: i32, %b: i32, %flag: i1) -> i32 {
// Both targets are the same, operands differ
cond_br %flag, ^bb1(%a : i32), ^bb1(%b : i32)
^bb1(%x : i32) :
return %x : i32
}
```
}];
let arguments = (ins I1:$condition,
Variadic<AnyType>:$trueDestOperands,
Variadic<AnyType>:$falseDestOperands);
let successors = (successor AnySuccessor:$trueDest, AnySuccessor:$falseDest);
let builders = [
OpBuilder<(ins "Value":$condition, "Block *":$trueDest,
"ValueRange":$trueOperands, "Block *":$falseDest,
"ValueRange":$falseOperands), [{
build($_builder, $_state, condition, trueOperands, falseOperands, trueDest,
falseDest);
}]>,
OpBuilder<(ins "Value":$condition, "Block *":$trueDest,
"Block *":$falseDest, CArg<"ValueRange", "{}">:$falseOperands), [{
build($_builder, $_state, condition, trueDest, ValueRange(), falseDest,
falseOperands);
}]>];
let extraClassDeclaration = [{
// These are the indices into the dests list.
enum { trueIndex = 0, falseIndex = 1 };
// Accessors for operands to the 'true' destination.
Value getTrueOperand(unsigned idx) {
assert(idx < getNumTrueOperands());
return getOperand(getTrueDestOperandIndex() + idx);
}
void setTrueOperand(unsigned idx, Value value) {
assert(idx < getNumTrueOperands());
setOperand(getTrueDestOperandIndex() + idx, value);
}
unsigned getNumTrueOperands() { return getTrueOperands().size(); }
/// Erase the operand at 'index' from the true operand list.
void eraseTrueOperand(unsigned index) {
getTrueDestOperandsMutable().erase(index);
}
// Accessors for operands to the 'false' destination.
Value getFalseOperand(unsigned idx) {
assert(idx < getNumFalseOperands());
return getOperand(getFalseDestOperandIndex() + idx);
}
void setFalseOperand(unsigned idx, Value value) {
assert(idx < getNumFalseOperands());
setOperand(getFalseDestOperandIndex() + idx, value);
}
operand_range getTrueOperands() { return getTrueDestOperands(); }
operand_range getFalseOperands() { return getFalseDestOperands(); }
unsigned getNumFalseOperands() { return getFalseOperands().size(); }
/// Erase the operand at 'index' from the false operand list.
void eraseFalseOperand(unsigned index) {
getFalseDestOperandsMutable().erase(index);
}
private:
/// Get the index of the first true destination operand.
unsigned getTrueDestOperandIndex() { return 1; }
/// Get the index of the first false destination operand.
unsigned getFalseDestOperandIndex() {
return getTrueDestOperandIndex() + getNumTrueOperands();
}
}];
let hasCanonicalizer = 1;
let assemblyFormat = [{
$condition `,`
$trueDest (`(` $trueDestOperands^ `:` type($trueDestOperands) `)`)? `,`
$falseDest (`(` $falseDestOperands^ `:` type($falseDestOperands) `)`)?
attr-dict
}];
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// ConstantOp // ConstantOp
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -443,93 +255,4 @@ def ReturnOp : Std_Op<"return", [NoSideEffect, HasParent<"FuncOp">,
let hasVerifier = 1; let hasVerifier = 1;
} }
//===----------------------------------------------------------------------===//
// SwitchOp
//===----------------------------------------------------------------------===//
def SwitchOp : Std_Op<"switch",
[AttrSizedOperandSegments,
DeclareOpInterfaceMethods<BranchOpInterface, ["getSuccessorForOperands"]>,
NoSideEffect, Terminator]> {
let summary = "switch operation";
let description = [{
The `switch` terminator operation represents a switch on a signless integer
value. If the flag matches one of the specified cases, then the
corresponding destination is jumped to. If the flag does not match any of
the cases, the default destination is jumped to. The count and types of
operands must align with the arguments in the corresponding target blocks.
Example:
```mlir
switch %flag : i32, [
default: ^bb1(%a : i32),
42: ^bb1(%b : i32),
43: ^bb3(%c : i32)
]
```
}];
let arguments = (ins
AnyInteger:$flag,
Variadic<AnyType>:$defaultOperands,
VariadicOfVariadic<AnyType, "case_operand_segments">:$caseOperands,
OptionalAttr<AnyIntElementsAttr>:$case_values,
I32ElementsAttr:$case_operand_segments
);
let successors = (successor
AnySuccessor:$defaultDestination,
VariadicSuccessor<AnySuccessor>:$caseDestinations
);
let builders = [
OpBuilder<(ins "Value":$flag,
"Block *":$defaultDestination,
"ValueRange":$defaultOperands,
CArg<"ArrayRef<APInt>", "{}">:$caseValues,
CArg<"BlockRange", "{}">:$caseDestinations,
CArg<"ArrayRef<ValueRange>", "{}">:$caseOperands)>,
OpBuilder<(ins "Value":$flag,
"Block *":$defaultDestination,
"ValueRange":$defaultOperands,
CArg<"ArrayRef<int32_t>", "{}">:$caseValues,
CArg<"BlockRange", "{}">:$caseDestinations,
CArg<"ArrayRef<ValueRange>", "{}">:$caseOperands)>,
OpBuilder<(ins "Value":$flag,
"Block *":$defaultDestination,
"ValueRange":$defaultOperands,
CArg<"DenseIntElementsAttr", "{}">:$caseValues,
CArg<"BlockRange", "{}">:$caseDestinations,
CArg<"ArrayRef<ValueRange>", "{}">:$caseOperands)>
];
let assemblyFormat = [{
$flag `:` type($flag) `,` `[` `\n`
custom<SwitchOpCases>(ref(type($flag)),$defaultDestination,
$defaultOperands,
type($defaultOperands),
$case_values,
$caseDestinations,
$caseOperands,
type($caseOperands))
`]`
attr-dict
}];
let extraClassDeclaration = [{
/// Return the operands for the case destination block at the given index.
OperandRange getCaseOperands(unsigned index) {
return getCaseOperands()[index];
}
/// Return a mutable range of operands for the case destination block at the
/// given index.
MutableOperandRange getCaseOperandsMutable(unsigned index) {
return getCaseOperandsMutable()[index];
}
}];
let hasCanonicalizer = 1;
let hasVerifier = 1;
}
#endif // STANDARD_OPS #endif // STANDARD_OPS

View File

@ -22,6 +22,7 @@
#include "mlir/Dialect/Async/IR/Async.h" #include "mlir/Dialect/Async/IR/Async.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h" #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Complex/IR/Complex.h" #include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
#include "mlir/Dialect/DLTI/DLTI.h" #include "mlir/Dialect/DLTI/DLTI.h"
#include "mlir/Dialect/EmitC/IR/EmitC.h" #include "mlir/Dialect/EmitC/IR/EmitC.h"
#include "mlir/Dialect/GPU/GPUDialect.h" #include "mlir/Dialect/GPU/GPUDialect.h"
@ -61,6 +62,7 @@ inline void registerAllDialects(DialectRegistry &registry) {
arm_neon::ArmNeonDialect, arm_neon::ArmNeonDialect,
async::AsyncDialect, async::AsyncDialect,
bufferization::BufferizationDialect, bufferization::BufferizationDialect,
cf::ControlFlowDialect,
complex::ComplexDialect, complex::ComplexDialect,
DLTIDialect, DLTIDialect,
emitc::EmitCDialect, emitc::EmitCDialect,

View File

@ -6,6 +6,8 @@ add_subdirectory(AsyncToLLVM)
add_subdirectory(BufferizationToMemRef) add_subdirectory(BufferizationToMemRef)
add_subdirectory(ComplexToLLVM) add_subdirectory(ComplexToLLVM)
add_subdirectory(ComplexToStandard) add_subdirectory(ComplexToStandard)
add_subdirectory(ControlFlowToLLVM)
add_subdirectory(ControlFlowToSPIRV)
add_subdirectory(GPUCommon) add_subdirectory(GPUCommon)
add_subdirectory(GPUToNVVM) add_subdirectory(GPUToNVVM)
add_subdirectory(GPUToROCDL) add_subdirectory(GPUToROCDL)
@ -25,10 +27,10 @@ add_subdirectory(OpenACCToSCF)
add_subdirectory(OpenMPToLLVM) add_subdirectory(OpenMPToLLVM)
add_subdirectory(PDLToPDLInterp) add_subdirectory(PDLToPDLInterp)
add_subdirectory(ReconcileUnrealizedCasts) add_subdirectory(ReconcileUnrealizedCasts)
add_subdirectory(SCFToControlFlow)
add_subdirectory(SCFToGPU) add_subdirectory(SCFToGPU)
add_subdirectory(SCFToOpenMP) add_subdirectory(SCFToOpenMP)
add_subdirectory(SCFToSPIRV) add_subdirectory(SCFToSPIRV)
add_subdirectory(SCFToStandard)
add_subdirectory(ShapeToStandard) add_subdirectory(ShapeToStandard)
add_subdirectory(SPIRVToLLVM) add_subdirectory(SPIRVToLLVM)
add_subdirectory(StandardToLLVM) add_subdirectory(StandardToLLVM)

View File

@ -0,0 +1,21 @@
# Conversion library lowering the ControlFlow dialect to the LLVM dialect.
add_mlir_conversion_library(MLIRControlFlowToLLVM
  ControlFlowToLLVM.cpp
  ADDITIONAL_HEADER_DIRS
  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/ControlFlowToLLVM
  DEPENDS
  MLIRConversionPassIncGen
  intrinsics_gen
  LINK_COMPONENTS
  Core
  LINK_LIBS PUBLIC
  MLIRAnalysis
  MLIRControlFlow
  MLIRLLVMCommonConversion
  MLIRLLVMIR
  MLIRPass
  MLIRTransformUtils
  )

View File

@ -0,0 +1,148 @@
//===- ControlFlowToLLVM.cpp - ControlFlow to LLVM dialect conversion -----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a pass to convert the ControlFlow dialect into the
// LLVM IR dialect.
//
//===----------------------------------------------------------------------===//
#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
#include "../PassDetail.h"
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Conversion/LLVMCommon/VectorPattern.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/DialectConversion.h"
#include <functional>
using namespace mlir;
#define PASS_NAME "convert-cf-to-llvm"
namespace {
/// Lower `cf.assert`. The default lowering calls the `abort` function if the
/// assertion is violated and has no effect otherwise. The failure message is
/// ignored by the default lowering but should be propagated by any custom
/// lowering.
struct AssertOpLowering : public ConvertOpToLLVMPattern<cf::AssertOp> {
  using ConvertOpToLLVMPattern<cf::AssertOp>::ConvertOpToLLVMPattern;
  LogicalResult
  matchAndRewrite(cf::AssertOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    // Insert the `abort` declaration if necessary.
    auto module = op->getParentOfType<ModuleOp>();
    auto abortFunc = module.lookupSymbol<LLVM::LLVMFuncOp>("abort");
    if (!abortFunc) {
      OpBuilder::InsertionGuard guard(rewriter);
      rewriter.setInsertionPointToStart(module.getBody());
      auto abortFuncTy = LLVM::LLVMFunctionType::get(getVoidType(), {});
      abortFunc = rewriter.create<LLVM::LLVMFuncOp>(rewriter.getUnknownLoc(),
                                                    "abort", abortFuncTy);
    }
    // Split block at `assert` operation: everything after the assert resumes
    // in `continuationBlock`.
    Block *opBlock = rewriter.getInsertionBlock();
    auto opPosition = rewriter.getInsertionPoint();
    Block *continuationBlock = rewriter.splitBlock(opBlock, opPosition);
    // Generate IR to call `abort`. This block is only reached on failure and
    // never returns, hence the `llvm.unreachable` terminator.
    Block *failureBlock = rewriter.createBlock(opBlock->getParent());
    rewriter.create<LLVM::CallOp>(loc, abortFunc, llvm::None);
    rewriter.create<LLVM::UnreachableOp>(loc);
    // Generate assertion test: branch to the continuation when the condition
    // holds, otherwise to the failure block.
    rewriter.setInsertionPointToEnd(opBlock);
    rewriter.replaceOpWithNewOp<LLVM::CondBrOp>(
        op, adaptor.getArg(), continuationBlock, failureBlock);
    return success();
  }
};
/// Utility base rewriting a terminator operation with successors into its
/// LLVM dialect counterpart: operands come from the adaptor (i.e. already
/// type-converted), while successors and attributes are forwarded verbatim.
template <typename SourceOp, typename TargetOp>
struct OneToOneLLVMTerminatorLowering
    : public ConvertOpToLLVMPattern<SourceOp> {
  using ConvertOpToLLVMPattern<SourceOp>::ConvertOpToLLVMPattern;
  using Base = OneToOneLLVMTerminatorLowering<SourceOp, TargetOp>;

  LogicalResult
  matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto successors = op->getSuccessors();
    auto attributes = op->getAttrs();
    rewriter.replaceOpWithNewOp<TargetOp>(op, adaptor.getOperands(), successors,
                                          attributes);
    return success();
  }
};
// FIXME: this should be tablegen'ed as well.
/// Lowers `cf.br` to `llvm.br`.
struct BranchOpLowering
    : public OneToOneLLVMTerminatorLowering<cf::BranchOp, LLVM::BrOp> {
  using Base::Base;
};
/// Lowers `cf.cond_br` to `llvm.cond_br`.
struct CondBranchOpLowering
    : public OneToOneLLVMTerminatorLowering<cf::CondBranchOp, LLVM::CondBrOp> {
  using Base::Base;
};
/// Lowers `cf.switch` to `llvm.switch`.
struct SwitchOpLowering
    : public OneToOneLLVMTerminatorLowering<cf::SwitchOp, LLVM::SwitchOp> {
  using Base::Base;
};
} // namespace
/// Populate `patterns` with the lowerings for every ControlFlow dialect
/// operation, using `converter` for type conversion.
void mlir::cf::populateControlFlowToLLVMConversionPatterns(
    LLVMTypeConverter &converter, RewritePatternSet &patterns) {
  patterns.add<AssertOpLowering, BranchOpLowering, CondBranchOpLowering,
               SwitchOpLowering>(converter);
}
//===----------------------------------------------------------------------===//
// Pass Definition
//===----------------------------------------------------------------------===//
namespace {
/// Pass lowering ControlFlow dialect operations to their LLVM dialect
/// counterparts.
struct ConvertControlFlowToLLVM
    : public ConvertControlFlowToLLVMBase<ConvertControlFlowToLLVM> {
  ConvertControlFlowToLLVM() = default;

  /// Run the dialect converter on the module.
  void runOnOperation() override {
    MLIRContext *ctx = &getContext();

    // Honor an explicitly requested index bitwidth; otherwise derive it from
    // the data layout.
    LowerToLLVMOptions options(ctx);
    if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout)
      options.overrideIndexBitwidth(indexBitwidth);
    LLVMTypeConverter converter(ctx, options);

    RewritePatternSet patterns(ctx);
    mlir::cf::populateControlFlowToLLVMConversionPatterns(converter, patterns);

    LLVMConversionTarget target(getContext());
    if (failed(applyPartialConversion(getOperation(), target,
                                      std::move(patterns))))
      signalPassFailure();
  }
};
} // namespace
/// Create an instance of the ControlFlow-to-LLVM conversion pass with default
/// options.
std::unique_ptr<Pass> mlir::cf::createConvertControlFlowToLLVMPass() {
  return std::make_unique<ConvertControlFlowToLLVM>();
}

View File

@ -0,0 +1,19 @@
# Conversion library lowering the ControlFlow dialect to the SPIR-V dialect.
add_mlir_conversion_library(MLIRControlFlowToSPIRV
  ControlFlowToSPIRV.cpp
  ADDITIONAL_HEADER_DIRS
  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/SPIRV
  ${MLIR_MAIN_INCLUDE_DIR}/mlir/IR
  DEPENDS
  MLIRConversionPassIncGen
  LINK_LIBS PUBLIC
  MLIRIR
  MLIRControlFlow
  MLIRPass
  MLIRSPIRV
  MLIRSPIRVConversion
  MLIRSupport
  MLIRTransformUtils
  )

View File

@ -0,0 +1,73 @@
//===- ControlFlowToSPIRV.cpp - ControlFlow to SPIR-V Patterns ------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements patterns to convert the ControlFlow dialect to the
// SPIR-V dialect.
//
//===----------------------------------------------------------------------===//
#include "mlir/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRV.h"
#include "../SPIRVCommon/Pattern.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
#include "mlir/Dialect/SPIRV/Utils/LayoutUtils.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/Support/LogicalResult.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/Support/Debug.h"
#define DEBUG_TYPE "cf-to-spirv-pattern"
using namespace mlir;
//===----------------------------------------------------------------------===//
// Operation conversion
//===----------------------------------------------------------------------===//
namespace {
/// Rewrites an unconditional `cf.br` into `spv.Branch`, forwarding the
/// already-converted successor operands.
struct BranchOpPattern final : public OpConversionPattern<cf::BranchOp> {
  using OpConversionPattern<cf::BranchOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(cf::BranchOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Block *successor = op.getDest();
    rewriter.replaceOpWithNewOp<spirv::BranchOp>(op, successor,
                                                 adaptor.getDestOperands());
    return success();
  }
};
/// Rewrites `cf.cond_br` into `spv.BranchConditional`, forwarding the
/// already-converted operands of both successors.
struct CondBranchOpPattern final
    : public OpConversionPattern<cf::CondBranchOp> {
  using OpConversionPattern<cf::CondBranchOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(cf::CondBranchOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Value condition = op.getCondition();
    rewriter.replaceOpWithNewOp<spirv::BranchConditionalOp>(
        op, condition, op.getTrueDest(), adaptor.getTrueDestOperands(),
        op.getFalseDest(), adaptor.getFalseDestOperands());
    return success();
  }
};
} // namespace
//===----------------------------------------------------------------------===//
// Pattern population
//===----------------------------------------------------------------------===//
/// Populate `patterns` with the ControlFlow-to-SPIR-V lowerings, using
/// `typeConverter` for type conversion.
void mlir::cf::populateControlFlowToSPIRVPatterns(
    SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns) {
  patterns.add<BranchOpPattern, CondBranchOpPattern>(typeConverter,
                                                     patterns.getContext());
}

View File

@ -14,6 +14,7 @@
#include "mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h" #include "mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h"
#include "mlir/Conversion/ArithmeticToLLVM/ArithmeticToLLVM.h" #include "mlir/Conversion/ArithmeticToLLVM/ArithmeticToLLVM.h"
#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h" #include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/LoweringOptions.h" #include "mlir/Conversion/LLVMCommon/LoweringOptions.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h" #include "mlir/Conversion/LLVMCommon/TypeConverter.h"
@ -172,8 +173,8 @@ struct LowerGpuOpsToNVVMOpsPass
populateGpuRewritePatterns(patterns); populateGpuRewritePatterns(patterns);
(void)applyPatternsAndFoldGreedily(m, std::move(patterns)); (void)applyPatternsAndFoldGreedily(m, std::move(patterns));
mlir::arith::populateArithmeticToLLVMConversionPatterns(converter, arith::populateArithmeticToLLVMConversionPatterns(converter, llvmPatterns);
llvmPatterns); cf::populateControlFlowToLLVMConversionPatterns(converter, llvmPatterns);
populateStdToLLVMConversionPatterns(converter, llvmPatterns); populateStdToLLVMConversionPatterns(converter, llvmPatterns);
populateMemRefToLLVMConversionPatterns(converter, llvmPatterns); populateMemRefToLLVMConversionPatterns(converter, llvmPatterns);
populateGpuToNVVMConversionPatterns(converter, llvmPatterns); populateGpuToNVVMConversionPatterns(converter, llvmPatterns);

View File

@ -19,7 +19,7 @@ add_mlir_conversion_library(MLIRLinalgToLLVM
MLIRLLVMCommonConversion MLIRLLVMCommonConversion
MLIRLLVMIR MLIRLLVMIR
MLIRMemRefToLLVM MLIRMemRefToLLVM
MLIRSCFToStandard MLIRSCFToControlFlow
MLIRTransforms MLIRTransforms
MLIRVectorToLLVM MLIRVectorToLLVM
MLIRVectorToSCF MLIRVectorToSCF

View File

@ -14,7 +14,7 @@
#include "mlir/Conversion/LLVMCommon/Pattern.h" #include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h" #include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h" #include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
#include "mlir/Conversion/SCFToStandard/SCFToStandard.h" #include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h"
#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h" #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
#include "mlir/Conversion/VectorToSCF/VectorToSCF.h" #include "mlir/Conversion/VectorToSCF/VectorToSCF.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h"

View File

@ -433,7 +433,7 @@ struct LoadStoreOpLowering : public ConvertOpToLLVMPattern<Derived> {
/// +---------------------------------+ /// +---------------------------------+
/// | <code before the AtomicRMWOp> | /// | <code before the AtomicRMWOp> |
/// | <compute initial %loaded> | /// | <compute initial %loaded> |
/// | br loop(%loaded) | /// | cf.br loop(%loaded) |
/// +---------------------------------+ /// +---------------------------------+
/// | /// |
/// -------| | /// -------| |
@ -444,7 +444,7 @@ struct LoadStoreOpLowering : public ConvertOpToLLVMPattern<Derived> {
/// | | %pair = cmpxchg | /// | | %pair = cmpxchg |
/// | | %ok = %pair[0] | /// | | %ok = %pair[0] |
/// | | %new = %pair[1] | /// | | %new = %pair[1] |
/// | | cond_br %ok, end, loop(%new) | /// | | cf.cond_br %ok, end, loop(%new) |
/// | +--------------------------------+ /// | +--------------------------------+
/// | | | /// | | |
/// |----------- | /// |----------- |

View File

@ -10,6 +10,7 @@
#include "../PassDetail.h" #include "../PassDetail.h"
#include "mlir/Conversion/ArithmeticToLLVM/ArithmeticToLLVM.h" #include "mlir/Conversion/ArithmeticToLLVM/ArithmeticToLLVM.h"
#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h" #include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h" #include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h" #include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
@ -66,7 +67,8 @@ void ConvertOpenMPToLLVMPass::runOnOperation() {
// Convert to OpenMP operations with LLVM IR dialect // Convert to OpenMP operations with LLVM IR dialect
RewritePatternSet patterns(&getContext()); RewritePatternSet patterns(&getContext());
LLVMTypeConverter converter(&getContext()); LLVMTypeConverter converter(&getContext());
mlir::arith::populateArithmeticToLLVMConversionPatterns(converter, patterns); arith::populateArithmeticToLLVMConversionPatterns(converter, patterns);
cf::populateControlFlowToLLVMConversionPatterns(converter, patterns);
populateMemRefToLLVMConversionPatterns(converter, patterns); populateMemRefToLLVMConversionPatterns(converter, patterns);
populateStdToLLVMConversionPatterns(converter, patterns); populateStdToLLVMConversionPatterns(converter, patterns);
populateOpenMPToLLVMConversionPatterns(converter, patterns); populateOpenMPToLLVMConversionPatterns(converter, patterns);

View File

@ -29,6 +29,10 @@ namespace arith {
class ArithmeticDialect; class ArithmeticDialect;
} // namespace arith } // namespace arith
namespace cf {
class ControlFlowDialect;
} // namespace cf
namespace complex { namespace complex {
class ComplexDialect; class ComplexDialect;
} // namespace complex } // namespace complex

View File

@ -1,8 +1,8 @@
add_mlir_conversion_library(MLIRSCFToStandard add_mlir_conversion_library(MLIRSCFToControlFlow
SCFToStandard.cpp SCFToControlFlow.cpp
ADDITIONAL_HEADER_DIRS ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/SCFToStandard ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/SCFToControlFlow
DEPENDS DEPENDS
MLIRConversionPassIncGen MLIRConversionPassIncGen
@ -12,6 +12,7 @@ add_mlir_conversion_library(MLIRSCFToStandard
LINK_LIBS PUBLIC LINK_LIBS PUBLIC
MLIRArithmetic MLIRArithmetic
MLIRControlFlow
MLIRSCF MLIRSCF
MLIRTransforms MLIRTransforms
) )

View File

@ -1,4 +1,4 @@
//===- SCFToStandard.cpp - ControlFlow to CFG conversion ------------------===// //===- SCFToControlFlow.cpp - SCF to CF conversion ------------------------===//
// //
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information. // See https://llvm.org/LICENSE.txt for license information.
@ -11,11 +11,11 @@
// //
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
#include "mlir/Conversion/SCFToStandard/SCFToStandard.h" #include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h"
#include "../PassDetail.h" #include "../PassDetail.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/SCF/SCF.h" #include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/BlockAndValueMapping.h" #include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/Builders.h" #include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h" #include "mlir/IR/BuiltinOps.h"
@ -29,7 +29,8 @@ using namespace mlir::scf;
namespace { namespace {
struct SCFToStandardPass : public SCFToStandardBase<SCFToStandardPass> { struct SCFToControlFlowPass
: public SCFToControlFlowBase<SCFToControlFlowPass> {
void runOnOperation() override; void runOnOperation() override;
}; };
@ -57,7 +58,7 @@ struct SCFToStandardPass : public SCFToStandardBase<SCFToStandardPass> {
// | <code before the ForOp> | // | <code before the ForOp> |
// | <definitions of %init...> | // | <definitions of %init...> |
// | <compute initial %iv value> | // | <compute initial %iv value> |
// | br cond(%iv, %init...) | // | cf.br cond(%iv, %init...) |
// +---------------------------------+ // +---------------------------------+
// | // |
// -------| | // -------| |
@ -65,7 +66,7 @@ struct SCFToStandardPass : public SCFToStandardBase<SCFToStandardPass> {
// | +--------------------------------+ // | +--------------------------------+
// | | cond(%iv, %init...): | // | | cond(%iv, %init...): |
// | | <compare %iv to upper bound> | // | | <compare %iv to upper bound> |
// | | cond_br %r, body, end | // | | cf.cond_br %r, body, end |
// | +--------------------------------+ // | +--------------------------------+
// | | | // | | |
// | | -------------| // | | -------------|
@ -83,7 +84,7 @@ struct SCFToStandardPass : public SCFToStandardBase<SCFToStandardPass> {
// | | <body contents> | | // | | <body contents> | |
// | | <operands of yield = %yields>| | // | | <operands of yield = %yields>| |
// | | %new_iv =<add step to %iv> | | // | | %new_iv =<add step to %iv> | |
// | | br cond(%new_iv, %yields) | | // | | cf.br cond(%new_iv, %yields) | |
// | +--------------------------------+ | // | +--------------------------------+ |
// | | | // | | |
// |----------- |-------------------- // |----------- |--------------------
@ -125,7 +126,7 @@ struct ForLowering : public OpRewritePattern<ForOp> {
// //
// +--------------------------------+ // +--------------------------------+
// | <code before the IfOp> | // | <code before the IfOp> |
// | cond_br %cond, %then, %else | // | cf.cond_br %cond, %then, %else |
// +--------------------------------+ // +--------------------------------+
// | | // | |
// | --------------| // | --------------|
@ -133,7 +134,7 @@ struct ForLowering : public OpRewritePattern<ForOp> {
// +--------------------------------+ | // +--------------------------------+ |
// | then: | | // | then: | |
// | <then contents> | | // | <then contents> | |
// | br continue | | // | cf.br continue | |
// +--------------------------------+ | // +--------------------------------+ |
// | | // | |
// |---------- |------------- // |---------- |-------------
@ -141,7 +142,7 @@ struct ForLowering : public OpRewritePattern<ForOp> {
// | +--------------------------------+ // | +--------------------------------+
// | | else: | // | | else: |
// | | <else contents> | // | | <else contents> |
// | | br continue | // | | cf.br continue |
// | +--------------------------------+ // | +--------------------------------+
// | | // | |
// ------| | // ------| |
@ -155,7 +156,7 @@ struct ForLowering : public OpRewritePattern<ForOp> {
// //
// +--------------------------------+ // +--------------------------------+
// | <code before the IfOp> | // | <code before the IfOp> |
// | cond_br %cond, %then, %else | // | cf.cond_br %cond, %then, %else |
// +--------------------------------+ // +--------------------------------+
// | | // | |
// | --------------| // | --------------|
@ -163,7 +164,7 @@ struct ForLowering : public OpRewritePattern<ForOp> {
// +--------------------------------+ | // +--------------------------------+ |
// | then: | | // | then: | |
// | <then contents> | | // | <then contents> | |
// | br dom(%args...) | | // | cf.br dom(%args...) | |
// +--------------------------------+ | // +--------------------------------+ |
// | | // | |
// |---------- |------------- // |---------- |-------------
@ -171,14 +172,14 @@ struct ForLowering : public OpRewritePattern<ForOp> {
// | +--------------------------------+ // | +--------------------------------+
// | | else: | // | | else: |
// | | <else contents> | // | | <else contents> |
// | | br dom(%args...) | // | | cf.br dom(%args...) |
// | +--------------------------------+ // | +--------------------------------+
// | | // | |
// ------| | // ------| |
// v v // v v
// +--------------------------------+ // +--------------------------------+
// | dom(%args...): | // | dom(%args...): |
// | br continue | // | cf.br continue |
// +--------------------------------+ // +--------------------------------+
// | // |
// v // v
@ -218,7 +219,7 @@ struct ParallelLowering : public OpRewritePattern<mlir::scf::ParallelOp> {
/// ///
/// +---------------------------------+ /// +---------------------------------+
/// | <code before the WhileOp> | /// | <code before the WhileOp> |
/// | br ^before(%operands...) | /// | cf.br ^before(%operands...) |
/// +---------------------------------+ /// +---------------------------------+
/// | /// |
/// -------| | /// -------| |
@ -233,7 +234,7 @@ struct ParallelLowering : public OpRewritePattern<mlir::scf::ParallelOp> {
/// | +--------------------------------+ /// | +--------------------------------+
/// | | ^before-last: /// | | ^before-last:
/// | | %cond = <compute condition> | /// | | %cond = <compute condition> |
/// | | cond_br %cond, | /// | | cf.cond_br %cond, |
/// | | ^after(%vals...), ^cont | /// | | ^after(%vals...), ^cont |
/// | +--------------------------------+ /// | +--------------------------------+
/// | | | /// | | |
@ -249,7 +250,7 @@ struct ParallelLowering : public OpRewritePattern<mlir::scf::ParallelOp> {
/// | +--------------------------------+ | /// | +--------------------------------+ |
/// | | ^after-last: | | /// | | ^after-last: | |
/// | | %yields... = <some payload> | | /// | | %yields... = <some payload> | |
/// | | br ^before(%yields...) | | /// | | cf.br ^before(%yields...) | |
/// | +--------------------------------+ | /// | +--------------------------------+ |
/// | | | /// | | |
/// |----------- |-------------------- /// |----------- |--------------------
@ -321,7 +322,7 @@ LogicalResult ForLowering::matchAndRewrite(ForOp forOp,
SmallVector<Value, 8> loopCarried; SmallVector<Value, 8> loopCarried;
loopCarried.push_back(stepped); loopCarried.push_back(stepped);
loopCarried.append(terminator->operand_begin(), terminator->operand_end()); loopCarried.append(terminator->operand_begin(), terminator->operand_end());
rewriter.create<BranchOp>(loc, conditionBlock, loopCarried); rewriter.create<cf::BranchOp>(loc, conditionBlock, loopCarried);
rewriter.eraseOp(terminator); rewriter.eraseOp(terminator);
// Compute loop bounds before branching to the condition. // Compute loop bounds before branching to the condition.
@ -337,15 +338,16 @@ LogicalResult ForLowering::matchAndRewrite(ForOp forOp,
destOperands.push_back(lowerBound); destOperands.push_back(lowerBound);
auto iterOperands = forOp.getIterOperands(); auto iterOperands = forOp.getIterOperands();
destOperands.append(iterOperands.begin(), iterOperands.end()); destOperands.append(iterOperands.begin(), iterOperands.end());
rewriter.create<BranchOp>(loc, conditionBlock, destOperands); rewriter.create<cf::BranchOp>(loc, conditionBlock, destOperands);
// With the body block done, we can fill in the condition block. // With the body block done, we can fill in the condition block.
rewriter.setInsertionPointToEnd(conditionBlock); rewriter.setInsertionPointToEnd(conditionBlock);
auto comparison = rewriter.create<arith::CmpIOp>( auto comparison = rewriter.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::slt, iv, upperBound); loc, arith::CmpIPredicate::slt, iv, upperBound);
rewriter.create<CondBranchOp>(loc, comparison, firstBodyBlock, rewriter.create<cf::CondBranchOp>(loc, comparison, firstBodyBlock,
ArrayRef<Value>(), endBlock, ArrayRef<Value>()); ArrayRef<Value>(), endBlock,
ArrayRef<Value>());
// The result of the loop operation is the values of the condition block // The result of the loop operation is the values of the condition block
// arguments except the induction variable on the last iteration. // arguments except the induction variable on the last iteration.
rewriter.replaceOp(forOp, conditionBlock->getArguments().drop_front()); rewriter.replaceOp(forOp, conditionBlock->getArguments().drop_front());
@ -369,7 +371,7 @@ LogicalResult IfLowering::matchAndRewrite(IfOp ifOp,
continueBlock = continueBlock =
rewriter.createBlock(remainingOpsBlock, ifOp.getResultTypes(), rewriter.createBlock(remainingOpsBlock, ifOp.getResultTypes(),
SmallVector<Location>(ifOp.getNumResults(), loc)); SmallVector<Location>(ifOp.getNumResults(), loc));
rewriter.create<BranchOp>(loc, remainingOpsBlock); rewriter.create<cf::BranchOp>(loc, remainingOpsBlock);
} }
// Move blocks from the "then" region to the region containing 'scf.if', // Move blocks from the "then" region to the region containing 'scf.if',
@ -379,7 +381,7 @@ LogicalResult IfLowering::matchAndRewrite(IfOp ifOp,
Operation *thenTerminator = thenRegion.back().getTerminator(); Operation *thenTerminator = thenRegion.back().getTerminator();
ValueRange thenTerminatorOperands = thenTerminator->getOperands(); ValueRange thenTerminatorOperands = thenTerminator->getOperands();
rewriter.setInsertionPointToEnd(&thenRegion.back()); rewriter.setInsertionPointToEnd(&thenRegion.back());
rewriter.create<BranchOp>(loc, continueBlock, thenTerminatorOperands); rewriter.create<cf::BranchOp>(loc, continueBlock, thenTerminatorOperands);
rewriter.eraseOp(thenTerminator); rewriter.eraseOp(thenTerminator);
rewriter.inlineRegionBefore(thenRegion, continueBlock); rewriter.inlineRegionBefore(thenRegion, continueBlock);
@ -393,15 +395,15 @@ LogicalResult IfLowering::matchAndRewrite(IfOp ifOp,
Operation *elseTerminator = elseRegion.back().getTerminator(); Operation *elseTerminator = elseRegion.back().getTerminator();
ValueRange elseTerminatorOperands = elseTerminator->getOperands(); ValueRange elseTerminatorOperands = elseTerminator->getOperands();
rewriter.setInsertionPointToEnd(&elseRegion.back()); rewriter.setInsertionPointToEnd(&elseRegion.back());
rewriter.create<BranchOp>(loc, continueBlock, elseTerminatorOperands); rewriter.create<cf::BranchOp>(loc, continueBlock, elseTerminatorOperands);
rewriter.eraseOp(elseTerminator); rewriter.eraseOp(elseTerminator);
rewriter.inlineRegionBefore(elseRegion, continueBlock); rewriter.inlineRegionBefore(elseRegion, continueBlock);
} }
rewriter.setInsertionPointToEnd(condBlock); rewriter.setInsertionPointToEnd(condBlock);
rewriter.create<CondBranchOp>(loc, ifOp.getCondition(), thenBlock, rewriter.create<cf::CondBranchOp>(loc, ifOp.getCondition(), thenBlock,
/*trueArgs=*/ArrayRef<Value>(), elseBlock, /*trueArgs=*/ArrayRef<Value>(), elseBlock,
/*falseArgs=*/ArrayRef<Value>()); /*falseArgs=*/ArrayRef<Value>());
// Ok, we're done! // Ok, we're done!
rewriter.replaceOp(ifOp, continueBlock->getArguments()); rewriter.replaceOp(ifOp, continueBlock->getArguments());
@ -419,13 +421,13 @@ ExecuteRegionLowering::matchAndRewrite(ExecuteRegionOp op,
auto &region = op.getRegion(); auto &region = op.getRegion();
rewriter.setInsertionPointToEnd(condBlock); rewriter.setInsertionPointToEnd(condBlock);
rewriter.create<BranchOp>(loc, &region.front()); rewriter.create<cf::BranchOp>(loc, &region.front());
for (Block &block : region) { for (Block &block : region) {
if (auto terminator = dyn_cast<scf::YieldOp>(block.getTerminator())) { if (auto terminator = dyn_cast<scf::YieldOp>(block.getTerminator())) {
ValueRange terminatorOperands = terminator->getOperands(); ValueRange terminatorOperands = terminator->getOperands();
rewriter.setInsertionPointToEnd(&block); rewriter.setInsertionPointToEnd(&block);
rewriter.create<BranchOp>(loc, remainingOpsBlock, terminatorOperands); rewriter.create<cf::BranchOp>(loc, remainingOpsBlock, terminatorOperands);
rewriter.eraseOp(terminator); rewriter.eraseOp(terminator);
} }
} }
@ -538,20 +540,21 @@ LogicalResult WhileLowering::matchAndRewrite(WhileOp whileOp,
// Branch to the "before" region. // Branch to the "before" region.
rewriter.setInsertionPointToEnd(currentBlock); rewriter.setInsertionPointToEnd(currentBlock);
rewriter.create<BranchOp>(loc, before, whileOp.getInits()); rewriter.create<cf::BranchOp>(loc, before, whileOp.getInits());
// Replace terminators with branches. Assuming bodies are SESE, which holds // Replace terminators with branches. Assuming bodies are SESE, which holds
// given only the patterns from this file, we only need to look at the last // given only the patterns from this file, we only need to look at the last
// block. This should be reconsidered if we allow break/continue in SCF. // block. This should be reconsidered if we allow break/continue in SCF.
rewriter.setInsertionPointToEnd(beforeLast); rewriter.setInsertionPointToEnd(beforeLast);
auto condOp = cast<ConditionOp>(beforeLast->getTerminator()); auto condOp = cast<ConditionOp>(beforeLast->getTerminator());
rewriter.replaceOpWithNewOp<CondBranchOp>(condOp, condOp.getCondition(), rewriter.replaceOpWithNewOp<cf::CondBranchOp>(condOp, condOp.getCondition(),
after, condOp.getArgs(), after, condOp.getArgs(),
continuation, ValueRange()); continuation, ValueRange());
rewriter.setInsertionPointToEnd(afterLast); rewriter.setInsertionPointToEnd(afterLast);
auto yieldOp = cast<scf::YieldOp>(afterLast->getTerminator()); auto yieldOp = cast<scf::YieldOp>(afterLast->getTerminator());
rewriter.replaceOpWithNewOp<BranchOp>(yieldOp, before, yieldOp.getResults()); rewriter.replaceOpWithNewOp<cf::BranchOp>(yieldOp, before,
yieldOp.getResults());
// Replace the op with values "yielded" from the "before" region, which are // Replace the op with values "yielded" from the "before" region, which are
// visible by dominance. // visible by dominance.
@ -593,14 +596,14 @@ DoWhileLowering::matchAndRewrite(WhileOp whileOp,
// Branch to the "before" region. // Branch to the "before" region.
rewriter.setInsertionPointToEnd(currentBlock); rewriter.setInsertionPointToEnd(currentBlock);
rewriter.create<BranchOp>(whileOp.getLoc(), before, whileOp.getInits()); rewriter.create<cf::BranchOp>(whileOp.getLoc(), before, whileOp.getInits());
// Loop around the "before" region based on condition. // Loop around the "before" region based on condition.
rewriter.setInsertionPointToEnd(beforeLast); rewriter.setInsertionPointToEnd(beforeLast);
auto condOp = cast<ConditionOp>(beforeLast->getTerminator()); auto condOp = cast<ConditionOp>(beforeLast->getTerminator());
rewriter.replaceOpWithNewOp<CondBranchOp>(condOp, condOp.getCondition(), rewriter.replaceOpWithNewOp<cf::CondBranchOp>(condOp, condOp.getCondition(),
before, condOp.getArgs(), before, condOp.getArgs(),
continuation, ValueRange()); continuation, ValueRange());
// Replace the op with values "yielded" from the "before" region, which are // Replace the op with values "yielded" from the "before" region, which are
// visible by dominance. // visible by dominance.
@ -609,17 +612,18 @@ DoWhileLowering::matchAndRewrite(WhileOp whileOp,
return success(); return success();
} }
void mlir::populateLoopToStdConversionPatterns(RewritePatternSet &patterns) { void mlir::populateSCFToControlFlowConversionPatterns(
RewritePatternSet &patterns) {
patterns.add<ForLowering, IfLowering, ParallelLowering, WhileLowering, patterns.add<ForLowering, IfLowering, ParallelLowering, WhileLowering,
ExecuteRegionLowering>(patterns.getContext()); ExecuteRegionLowering>(patterns.getContext());
patterns.add<DoWhileLowering>(patterns.getContext(), /*benefit=*/2); patterns.add<DoWhileLowering>(patterns.getContext(), /*benefit=*/2);
} }
void SCFToStandardPass::runOnOperation() { void SCFToControlFlowPass::runOnOperation() {
RewritePatternSet patterns(&getContext()); RewritePatternSet patterns(&getContext());
populateLoopToStdConversionPatterns(patterns); populateSCFToControlFlowConversionPatterns(patterns);
// Configure conversion to lower out scf.for, scf.if, scf.parallel and
// scf.while. Anything else is fine. // Configure conversion to lower out SCF operations.
ConversionTarget target(getContext()); ConversionTarget target(getContext());
target.addIllegalOp<scf::ForOp, scf::IfOp, scf::ParallelOp, scf::WhileOp, target.addIllegalOp<scf::ForOp, scf::IfOp, scf::ParallelOp, scf::WhileOp,
scf::ExecuteRegionOp>(); scf::ExecuteRegionOp>();
@ -629,6 +633,6 @@ void SCFToStandardPass::runOnOperation() {
signalPassFailure(); signalPassFailure();
} }
std::unique_ptr<Pass> mlir::createLowerToCFGPass() { std::unique_ptr<Pass> mlir::createConvertSCFToCFPass() {
return std::make_unique<SCFToStandardPass>(); return std::make_unique<SCFToControlFlowPass>();
} }

View File

@ -9,6 +9,7 @@
#include "mlir/Conversion/ShapeToStandard/ShapeToStandard.h" #include "mlir/Conversion/ShapeToStandard/ShapeToStandard.h"
#include "../PassDetail.h" #include "../PassDetail.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/SCF/SCF.h" #include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/Shape/IR/Shape.h" #include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/Dialect/StandardOps/IR/Ops.h"
@ -29,7 +30,7 @@ public:
using OpRewritePattern::OpRewritePattern; using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(shape::CstrRequireOp op, LogicalResult matchAndRewrite(shape::CstrRequireOp op,
PatternRewriter &rewriter) const override { PatternRewriter &rewriter) const override {
rewriter.create<AssertOp>(op.getLoc(), op.getPred(), op.getMsgAttr()); rewriter.create<cf::AssertOp>(op.getLoc(), op.getPred(), op.getMsgAttr());
rewriter.replaceOpWithNewOp<shape::ConstWitnessOp>(op, true); rewriter.replaceOpWithNewOp<shape::ConstWitnessOp>(op, true);
return success(); return success();
} }

View File

@ -14,6 +14,7 @@ add_mlir_conversion_library(MLIRStandardToLLVM
LINK_LIBS PUBLIC LINK_LIBS PUBLIC
MLIRAnalysis MLIRAnalysis
MLIRArithmeticToLLVM MLIRArithmeticToLLVM
MLIRControlFlowToLLVM
MLIRDataLayoutInterfaces MLIRDataLayoutInterfaces
MLIRLLVMCommonConversion MLIRLLVMCommonConversion
MLIRLLVMIR MLIRLLVMIR

View File

@ -14,6 +14,7 @@
#include "../PassDetail.h" #include "../PassDetail.h"
#include "mlir/Analysis/DataLayoutAnalysis.h" #include "mlir/Analysis/DataLayoutAnalysis.h"
#include "mlir/Conversion/ArithmeticToLLVM/ArithmeticToLLVM.h" #include "mlir/Conversion/ArithmeticToLLVM/ArithmeticToLLVM.h"
#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h" #include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h" #include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Conversion/LLVMCommon/VectorPattern.h" #include "mlir/Conversion/LLVMCommon/VectorPattern.h"
@ -387,48 +388,6 @@ struct BarePtrFuncOpConversion : public FuncOpConversionBase {
} }
}; };
/// Lower `std.assert`. The default lowering calls the `abort` function if the
/// assertion is violated and has no effect otherwise. The failure message is
/// ignored by the default lowering but should be propagated by any custom
/// lowering.
struct AssertOpLowering : public ConvertOpToLLVMPattern<AssertOp> {
using ConvertOpToLLVMPattern<AssertOp>::ConvertOpToLLVMPattern;
LogicalResult
matchAndRewrite(AssertOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto loc = op.getLoc();
// Insert the `abort` declaration if necessary.
auto module = op->getParentOfType<ModuleOp>();
auto abortFunc = module.lookupSymbol<LLVM::LLVMFuncOp>("abort");
if (!abortFunc) {
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPointToStart(module.getBody());
auto abortFuncTy = LLVM::LLVMFunctionType::get(getVoidType(), {});
abortFunc = rewriter.create<LLVM::LLVMFuncOp>(rewriter.getUnknownLoc(),
"abort", abortFuncTy);
}
// Split block at `assert` operation.
Block *opBlock = rewriter.getInsertionBlock();
auto opPosition = rewriter.getInsertionPoint();
Block *continuationBlock = rewriter.splitBlock(opBlock, opPosition);
// Generate IR to call `abort`.
Block *failureBlock = rewriter.createBlock(opBlock->getParent());
rewriter.create<LLVM::CallOp>(loc, abortFunc, llvm::None);
rewriter.create<LLVM::UnreachableOp>(loc);
// Generate assertion test.
rewriter.setInsertionPointToEnd(opBlock);
rewriter.replaceOpWithNewOp<LLVM::CondBrOp>(
op, adaptor.getArg(), continuationBlock, failureBlock);
return success();
}
};
struct ConstantOpLowering : public ConvertOpToLLVMPattern<ConstantOp> { struct ConstantOpLowering : public ConvertOpToLLVMPattern<ConstantOp> {
using ConvertOpToLLVMPattern<ConstantOp>::ConvertOpToLLVMPattern; using ConvertOpToLLVMPattern<ConstantOp>::ConvertOpToLLVMPattern;
@ -550,22 +509,6 @@ struct UnrealizedConversionCastOpLowering
} }
}; };
// Base class for LLVM IR lowering terminator operations with successors.
template <typename SourceOp, typename TargetOp>
struct OneToOneLLVMTerminatorLowering
: public ConvertOpToLLVMPattern<SourceOp> {
using ConvertOpToLLVMPattern<SourceOp>::ConvertOpToLLVMPattern;
using Super = OneToOneLLVMTerminatorLowering<SourceOp, TargetOp>;
LogicalResult
matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
rewriter.replaceOpWithNewOp<TargetOp>(op, adaptor.getOperands(),
op->getSuccessors(), op->getAttrs());
return success();
}
};
// Special lowering pattern for `ReturnOps`. Unlike all other operations, // Special lowering pattern for `ReturnOps`. Unlike all other operations,
// `ReturnOp` interacts with the function signature and must have as many // `ReturnOp` interacts with the function signature and must have as many
// operands as the function has return values. Because in LLVM IR, functions // operands as the function has return values. Because in LLVM IR, functions
@ -633,21 +576,6 @@ struct ReturnOpLowering : public ConvertOpToLLVMPattern<ReturnOp> {
return success(); return success();
} }
}; };
// FIXME: this should be tablegen'ed as well.
struct BranchOpLowering
: public OneToOneLLVMTerminatorLowering<BranchOp, LLVM::BrOp> {
using Super::Super;
};
struct CondBranchOpLowering
: public OneToOneLLVMTerminatorLowering<CondBranchOp, LLVM::CondBrOp> {
using Super::Super;
};
struct SwitchOpLowering
: public OneToOneLLVMTerminatorLowering<SwitchOp, LLVM::SwitchOp> {
using Super::Super;
};
} // namespace } // namespace
void mlir::populateStdToLLVMFuncOpConversionPattern( void mlir::populateStdToLLVMFuncOpConversionPattern(
@ -663,14 +591,10 @@ void mlir::populateStdToLLVMConversionPatterns(LLVMTypeConverter &converter,
populateStdToLLVMFuncOpConversionPattern(converter, patterns); populateStdToLLVMFuncOpConversionPattern(converter, patterns);
// clang-format off // clang-format off
patterns.add< patterns.add<
AssertOpLowering,
BranchOpLowering,
CallIndirectOpLowering, CallIndirectOpLowering,
CallOpLowering, CallOpLowering,
CondBranchOpLowering,
ConstantOpLowering, ConstantOpLowering,
ReturnOpLowering, ReturnOpLowering>(converter);
SwitchOpLowering>(converter);
// clang-format on // clang-format on
} }
@ -721,6 +645,7 @@ struct LLVMLoweringPass : public ConvertStandardToLLVMBase<LLVMLoweringPass> {
RewritePatternSet patterns(&getContext()); RewritePatternSet patterns(&getContext());
populateStdToLLVMConversionPatterns(typeConverter, patterns); populateStdToLLVMConversionPatterns(typeConverter, patterns);
arith::populateArithmeticToLLVMConversionPatterns(typeConverter, patterns); arith::populateArithmeticToLLVMConversionPatterns(typeConverter, patterns);
cf::populateControlFlowToLLVMConversionPatterns(typeConverter, patterns);
LLVMConversionTarget target(getContext()); LLVMConversionTarget target(getContext());
if (failed(applyPartialConversion(m, target, std::move(patterns)))) if (failed(applyPartialConversion(m, target, std::move(patterns))))

View File

@ -11,6 +11,7 @@ add_mlir_conversion_library(MLIRStandardToSPIRV
LINK_LIBS PUBLIC LINK_LIBS PUBLIC
MLIRArithmeticToSPIRV MLIRArithmeticToSPIRV
MLIRControlFlowToSPIRV
MLIRIR MLIRIR
MLIRMathToSPIRV MLIRMathToSPIRV
MLIRMemRef MLIRMemRef

View File

@ -46,24 +46,6 @@ public:
ConversionPatternRewriter &rewriter) const override; ConversionPatternRewriter &rewriter) const override;
}; };
/// Converts std.br to spv.Branch.
struct BranchOpPattern final : public OpConversionPattern<BranchOp> {
using OpConversionPattern<BranchOp>::OpConversionPattern;
LogicalResult
matchAndRewrite(BranchOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override;
};
/// Converts std.cond_br to spv.BranchConditional.
struct CondBranchOpPattern final : public OpConversionPattern<CondBranchOp> {
using OpConversionPattern<CondBranchOp>::OpConversionPattern;
LogicalResult
matchAndRewrite(CondBranchOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override;
};
/// Converts tensor.extract into loading using access chains from SPIR-V local /// Converts tensor.extract into loading using access chains from SPIR-V local
/// variables. /// variables.
class TensorExtractPattern final class TensorExtractPattern final
@ -146,31 +128,6 @@ ReturnOpPattern::matchAndRewrite(ReturnOp returnOp, OpAdaptor adaptor,
return success(); return success();
} }
//===----------------------------------------------------------------------===//
// BranchOpPattern
//===----------------------------------------------------------------------===//
LogicalResult
BranchOpPattern::matchAndRewrite(BranchOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
rewriter.replaceOpWithNewOp<spirv::BranchOp>(op, op.getDest(),
adaptor.getDestOperands());
return success();
}
//===----------------------------------------------------------------------===//
// CondBranchOpPattern
//===----------------------------------------------------------------------===//
LogicalResult CondBranchOpPattern::matchAndRewrite(
CondBranchOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
rewriter.replaceOpWithNewOp<spirv::BranchConditionalOp>(
op, op.getCondition(), op.getTrueDest(), adaptor.getTrueDestOperands(),
op.getFalseDest(), adaptor.getFalseDestOperands());
return success();
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Pattern population // Pattern population
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -189,8 +146,7 @@ void populateStandardToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
spirv::ElementwiseOpPattern<arith::MinSIOp, spirv::GLSLSMinOp>, spirv::ElementwiseOpPattern<arith::MinSIOp, spirv::GLSLSMinOp>,
spirv::ElementwiseOpPattern<arith::MinUIOp, spirv::GLSLUMinOp>, spirv::ElementwiseOpPattern<arith::MinUIOp, spirv::GLSLUMinOp>,
ReturnOpPattern, BranchOpPattern, CondBranchOpPattern>(typeConverter, ReturnOpPattern>(typeConverter, context);
context);
} }
void populateTensorToSPIRVPatterns(SPIRVTypeConverter &typeConverter, void populateTensorToSPIRVPatterns(SPIRVTypeConverter &typeConverter,

View File

@ -13,6 +13,7 @@
#include "mlir/Conversion/StandardToSPIRV/StandardToSPIRVPass.h" #include "mlir/Conversion/StandardToSPIRV/StandardToSPIRVPass.h"
#include "../PassDetail.h" #include "../PassDetail.h"
#include "mlir/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.h" #include "mlir/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.h"
#include "mlir/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRV.h"
#include "mlir/Conversion/MathToSPIRV/MathToSPIRV.h" #include "mlir/Conversion/MathToSPIRV/MathToSPIRV.h"
#include "mlir/Conversion/StandardToSPIRV/StandardToSPIRV.h" #include "mlir/Conversion/StandardToSPIRV/StandardToSPIRV.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
@ -40,9 +41,11 @@ void ConvertStandardToSPIRVPass::runOnOperation() {
options.emulateNon32BitScalarTypes = this->emulateNon32BitScalarTypes; options.emulateNon32BitScalarTypes = this->emulateNon32BitScalarTypes;
SPIRVTypeConverter typeConverter(targetAttr, options); SPIRVTypeConverter typeConverter(targetAttr, options);
// TODO ArithmeticToSPIRV cannot be applied separately to StandardToSPIRV // TODO ArithmeticToSPIRV/ControlFlowToSPIRV cannot be applied separately to
// StandardToSPIRV
RewritePatternSet patterns(context); RewritePatternSet patterns(context);
arith::populateArithmeticToSPIRVPatterns(typeConverter, patterns); arith::populateArithmeticToSPIRVPatterns(typeConverter, patterns);
cf::populateControlFlowToSPIRVPatterns(typeConverter, patterns);
populateMathToSPIRVPatterns(typeConverter, patterns); populateMathToSPIRVPatterns(typeConverter, patterns);
populateStandardToSPIRVPatterns(typeConverter, patterns); populateStandardToSPIRVPatterns(typeConverter, patterns);
populateTensorToSPIRVPatterns(typeConverter, /*byteCountThreshold=*/64, populateTensorToSPIRVPatterns(typeConverter, /*byteCountThreshold=*/64,

View File

@ -15,6 +15,7 @@
#include "mlir/Analysis/Liveness.h" #include "mlir/Analysis/Liveness.h"
#include "mlir/Dialect/Async/IR/Async.h" #include "mlir/Dialect/Async/IR/Async.h"
#include "mlir/Dialect/Async/Passes.h" #include "mlir/Dialect/Async/Passes.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/ImplicitLocOpBuilder.h" #include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/PatternMatch.h" #include "mlir/IR/PatternMatch.h"
@ -169,11 +170,11 @@ private:
/// ///
/// ^entry: /// ^entry:
/// %token = async.runtime.create : !async.token /// %token = async.runtime.create : !async.token
/// cond_br %cond, ^bb1, ^bb2 /// cf.cond_br %cond, ^bb1, ^bb2
/// ^bb1: /// ^bb1:
/// async.runtime.await %token /// async.runtime.await %token
/// async.runtime.drop_ref %token /// async.runtime.drop_ref %token
/// br ^bb2 /// cf.br ^bb2
/// ^bb2: /// ^bb2:
/// return /// return
/// ///
@ -185,14 +186,14 @@ private:
/// ///
/// ^entry: /// ^entry:
/// %token = async.runtime.create : !async.token /// %token = async.runtime.create : !async.token
/// cond_br %cond, ^bb1, ^reference_counting /// cf.cond_br %cond, ^bb1, ^reference_counting
/// ^bb1: /// ^bb1:
/// async.runtime.await %token /// async.runtime.await %token
/// async.runtime.drop_ref %token /// async.runtime.drop_ref %token
/// br ^bb2 /// cf.br ^bb2
/// ^reference_counting: /// ^reference_counting:
/// async.runtime.drop_ref %token /// async.runtime.drop_ref %token
/// br ^bb2 /// cf.br ^bb2
/// ^bb2: /// ^bb2:
/// return /// return
/// ///
@ -208,7 +209,7 @@ private:
/// async.coro.suspend %ret, ^suspend, ^resume, ^cleanup /// async.coro.suspend %ret, ^suspend, ^resume, ^cleanup
/// ^resume: /// ^resume:
/// %0 = async.runtime.load %value /// %0 = async.runtime.load %value
/// br ^cleanup /// cf.br ^cleanup
/// ^cleanup: /// ^cleanup:
/// ... /// ...
/// ^suspend: /// ^suspend:
@ -406,7 +407,7 @@ AsyncRuntimeRefCountingPass::addDropRefInDivergentLivenessSuccessor(
refCountingBlock = &successor->getParent()->emplaceBlock(); refCountingBlock = &successor->getParent()->emplaceBlock();
refCountingBlock->moveBefore(successor); refCountingBlock->moveBefore(successor);
OpBuilder builder = OpBuilder::atBlockEnd(refCountingBlock); OpBuilder builder = OpBuilder::atBlockEnd(refCountingBlock);
builder.create<BranchOp>(value.getLoc(), successor); builder.create<cf::BranchOp>(value.getLoc(), successor);
} }
OpBuilder builder = OpBuilder::atBlockBegin(refCountingBlock); OpBuilder builder = OpBuilder::atBlockBegin(refCountingBlock);

View File

@ -12,10 +12,11 @@
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
#include "PassDetail.h" #include "PassDetail.h"
#include "mlir/Conversion/SCFToStandard/SCFToStandard.h" #include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Async/IR/Async.h" #include "mlir/Dialect/Async/IR/Async.h"
#include "mlir/Dialect/Async/Passes.h" #include "mlir/Dialect/Async/Passes.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/SCF/SCF.h" #include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/BlockAndValueMapping.h" #include "mlir/IR/BlockAndValueMapping.h"
@ -105,18 +106,18 @@ struct CoroMachinery {
/// %value = <async value> : !async.value<T> // create async value /// %value = <async value> : !async.value<T> // create async value
/// %id = async.coro.id // create a coroutine id /// %id = async.coro.id // create a coroutine id
/// %hdl = async.coro.begin %id // create a coroutine handle /// %hdl = async.coro.begin %id // create a coroutine handle
/// br ^preexisting_entry_block /// cf.br ^preexisting_entry_block
/// ///
/// /* preexisting blocks modified to branch to the cleanup block */ /// /* preexisting blocks modified to branch to the cleanup block */
/// ///
/// ^set_error: // this block created lazily only if needed (see code below) /// ^set_error: // this block created lazily only if needed (see code below)
/// async.runtime.set_error %token : !async.token /// async.runtime.set_error %token : !async.token
/// async.runtime.set_error %value : !async.value<T> /// async.runtime.set_error %value : !async.value<T>
/// br ^cleanup /// cf.br ^cleanup
/// ///
/// ^cleanup: /// ^cleanup:
/// async.coro.free %hdl // delete the coroutine state /// async.coro.free %hdl // delete the coroutine state
/// br ^suspend /// cf.br ^suspend
/// ///
/// ^suspend: /// ^suspend:
/// async.coro.end %hdl // marks the end of a coroutine /// async.coro.end %hdl // marks the end of a coroutine
@ -147,7 +148,7 @@ static CoroMachinery setupCoroMachinery(FuncOp func) {
auto coroIdOp = builder.create<CoroIdOp>(CoroIdType::get(ctx)); auto coroIdOp = builder.create<CoroIdOp>(CoroIdType::get(ctx));
auto coroHdlOp = auto coroHdlOp =
builder.create<CoroBeginOp>(CoroHandleType::get(ctx), coroIdOp.id()); builder.create<CoroBeginOp>(CoroHandleType::get(ctx), coroIdOp.id());
builder.create<BranchOp>(originalEntryBlock); builder.create<cf::BranchOp>(originalEntryBlock);
Block *cleanupBlock = func.addBlock(); Block *cleanupBlock = func.addBlock();
Block *suspendBlock = func.addBlock(); Block *suspendBlock = func.addBlock();
@ -159,7 +160,7 @@ static CoroMachinery setupCoroMachinery(FuncOp func) {
builder.create<CoroFreeOp>(coroIdOp.id(), coroHdlOp.handle()); builder.create<CoroFreeOp>(coroIdOp.id(), coroHdlOp.handle());
// Branch into the suspend block. // Branch into the suspend block.
builder.create<BranchOp>(suspendBlock); builder.create<cf::BranchOp>(suspendBlock);
// ------------------------------------------------------------------------ // // ------------------------------------------------------------------------ //
// Coroutine suspend block: mark the end of a coroutine and return allocated // Coroutine suspend block: mark the end of a coroutine and return allocated
@ -186,7 +187,7 @@ static CoroMachinery setupCoroMachinery(FuncOp func) {
Operation *terminator = block.getTerminator(); Operation *terminator = block.getTerminator();
if (auto yield = dyn_cast<YieldOp>(terminator)) { if (auto yield = dyn_cast<YieldOp>(terminator)) {
builder.setInsertionPointToEnd(&block); builder.setInsertionPointToEnd(&block);
builder.create<BranchOp>(cleanupBlock); builder.create<cf::BranchOp>(cleanupBlock);
} }
} }
@ -227,7 +228,7 @@ static Block *setupSetErrorBlock(CoroMachinery &coro) {
builder.create<RuntimeSetErrorOp>(retValue); builder.create<RuntimeSetErrorOp>(retValue);
// Branch into the cleanup block. // Branch into the cleanup block.
builder.create<BranchOp>(coro.cleanup); builder.create<cf::BranchOp>(coro.cleanup);
return coro.setError; return coro.setError;
} }
@ -305,7 +306,7 @@ outlineExecuteOp(SymbolTable &symbolTable, ExecuteOp execute) {
// Async resume operation (execution will be resumed in a thread managed by // Async resume operation (execution will be resumed in a thread managed by
// the async runtime). // the async runtime).
{ {
BranchOp branch = cast<BranchOp>(coro.entry->getTerminator()); cf::BranchOp branch = cast<cf::BranchOp>(coro.entry->getTerminator());
builder.setInsertionPointToEnd(coro.entry); builder.setInsertionPointToEnd(coro.entry);
// Save the coroutine state: async.coro.save // Save the coroutine state: async.coro.save
@ -419,8 +420,8 @@ public:
isError, builder.create<arith::ConstantOp>( isError, builder.create<arith::ConstantOp>(
loc, i1, builder.getIntegerAttr(i1, 1))); loc, i1, builder.getIntegerAttr(i1, 1)));
builder.create<AssertOp>(notError, builder.create<cf::AssertOp>(notError,
"Awaited async operand is in error state"); "Awaited async operand is in error state");
} }
// Inside the coroutine we convert await operation into coroutine suspension // Inside the coroutine we convert await operation into coroutine suspension
@ -452,11 +453,11 @@ public:
// Check if the awaited value is in the error state. // Check if the awaited value is in the error state.
builder.setInsertionPointToStart(resume); builder.setInsertionPointToStart(resume);
auto isError = builder.create<RuntimeIsErrorOp>(loc, i1, operand); auto isError = builder.create<RuntimeIsErrorOp>(loc, i1, operand);
builder.create<CondBranchOp>(isError, builder.create<cf::CondBranchOp>(isError,
/*trueDest=*/setupSetErrorBlock(coro), /*trueDest=*/setupSetErrorBlock(coro),
/*trueArgs=*/ArrayRef<Value>(), /*trueArgs=*/ArrayRef<Value>(),
/*falseDest=*/continuation, /*falseDest=*/continuation,
/*falseArgs=*/ArrayRef<Value>()); /*falseArgs=*/ArrayRef<Value>());
// Make sure that replacement value will be constructed in the // Make sure that replacement value will be constructed in the
// continuation block. // continuation block.
@ -560,18 +561,18 @@ private:
}; };
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Convert std.assert operation to cond_br into `set_error` block. // Convert std.assert operation to cf.cond_br into `set_error` block.
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
class AssertOpLowering : public OpConversionPattern<AssertOp> { class AssertOpLowering : public OpConversionPattern<cf::AssertOp> {
public: public:
AssertOpLowering(MLIRContext *ctx, AssertOpLowering(MLIRContext *ctx,
llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions) llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions)
: OpConversionPattern<AssertOp>(ctx), : OpConversionPattern<cf::AssertOp>(ctx),
outlinedFunctions(outlinedFunctions) {} outlinedFunctions(outlinedFunctions) {}
LogicalResult LogicalResult
matchAndRewrite(AssertOp op, OpAdaptor adaptor, matchAndRewrite(cf::AssertOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override { ConversionPatternRewriter &rewriter) const override {
// Check if assert operation is inside the async coroutine function. // Check if assert operation is inside the async coroutine function.
auto func = op->template getParentOfType<FuncOp>(); auto func = op->template getParentOfType<FuncOp>();
@ -585,11 +586,11 @@ public:
Block *cont = rewriter.splitBlock(op->getBlock(), Block::iterator(op)); Block *cont = rewriter.splitBlock(op->getBlock(), Block::iterator(op));
rewriter.setInsertionPointToEnd(cont->getPrevNode()); rewriter.setInsertionPointToEnd(cont->getPrevNode());
rewriter.create<CondBranchOp>(loc, adaptor.getArg(), rewriter.create<cf::CondBranchOp>(loc, adaptor.getArg(),
/*trueDest=*/cont, /*trueDest=*/cont,
/*trueArgs=*/ArrayRef<Value>(), /*trueArgs=*/ArrayRef<Value>(),
/*falseDest=*/setupSetErrorBlock(coro), /*falseDest=*/setupSetErrorBlock(coro),
/*falseArgs=*/ArrayRef<Value>()); /*falseArgs=*/ArrayRef<Value>());
rewriter.eraseOp(op); rewriter.eraseOp(op);
return success(); return success();
@ -765,7 +766,7 @@ void AsyncToAsyncRuntimePass::runOnOperation() {
// and we have to make sure that structured control flow operations with async // and we have to make sure that structured control flow operations with async
// operations in nested regions will be converted to branch-based control flow // operations in nested regions will be converted to branch-based control flow
// before we add the coroutine basic blocks. // before we add the coroutine basic blocks.
populateLoopToStdConversionPatterns(asyncPatterns); populateSCFToControlFlowConversionPatterns(asyncPatterns);
// Async lowering does not use type converter because it must preserve all // Async lowering does not use type converter because it must preserve all
// types for async.runtime operations. // types for async.runtime operations.
@ -792,14 +793,15 @@ void AsyncToAsyncRuntimePass::runOnOperation() {
}); });
return !walkResult.wasInterrupted(); return !walkResult.wasInterrupted();
}); });
runtimeTarget.addLegalOp<AssertOp, arith::XOrIOp, arith::ConstantOp, runtimeTarget.addLegalOp<cf::AssertOp, arith::XOrIOp, arith::ConstantOp,
ConstantOp, BranchOp, CondBranchOp>(); ConstantOp, cf::BranchOp, cf::CondBranchOp>();
// Assertions must be converted to runtime errors inside async functions. // Assertions must be converted to runtime errors inside async functions.
runtimeTarget.addDynamicallyLegalOp<AssertOp>([&](AssertOp op) -> bool { runtimeTarget.addDynamicallyLegalOp<cf::AssertOp>(
auto func = op->getParentOfType<FuncOp>(); [&](cf::AssertOp op) -> bool {
return outlinedFunctions.find(func) == outlinedFunctions.end(); auto func = op->getParentOfType<FuncOp>();
}); return outlinedFunctions.find(func) == outlinedFunctions.end();
});
if (eliminateBlockingAwaitOps) if (eliminateBlockingAwaitOps)
runtimeTarget.addDynamicallyLegalOp<RuntimeAwaitOp>( runtimeTarget.addDynamicallyLegalOp<RuntimeAwaitOp>(

View File

@ -17,7 +17,7 @@ add_mlir_dialect_library(MLIRAsyncTransforms
MLIRIR MLIRIR
MLIRPass MLIRPass
MLIRSCF MLIRSCF
MLIRSCFToStandard MLIRSCFToControlFlow
MLIRStandard MLIRStandard
MLIRTransforms MLIRTransforms
MLIRTransformUtils MLIRTransformUtils

View File

@ -18,12 +18,12 @@
// (using the BufferViewFlowAnalysis class). Consider the following example: // (using the BufferViewFlowAnalysis class). Consider the following example:
// //
// ^bb0(%arg0): // ^bb0(%arg0):
// cond_br %cond, ^bb1, ^bb2 // cf.cond_br %cond, ^bb1, ^bb2
// ^bb1: // ^bb1:
// br ^exit(%arg0) // cf.br ^exit(%arg0)
// ^bb2: // ^bb2:
// %new_value = ... // %new_value = ...
// br ^exit(%new_value) // cf.br ^exit(%new_value)
// ^exit(%arg1): // ^exit(%arg1):
// return %arg1; // return %arg1;
// //

View File

@ -6,6 +6,7 @@ add_subdirectory(Async)
add_subdirectory(AMX) add_subdirectory(AMX)
add_subdirectory(Bufferization) add_subdirectory(Bufferization)
add_subdirectory(Complex) add_subdirectory(Complex)
add_subdirectory(ControlFlow)
add_subdirectory(DLTI) add_subdirectory(DLTI)
add_subdirectory(EmitC) add_subdirectory(EmitC)
add_subdirectory(GPU) add_subdirectory(GPU)

View File

@ -0,0 +1 @@
add_subdirectory(IR)

View File

@ -0,0 +1,15 @@
add_mlir_dialect_library(MLIRControlFlow
ControlFlowOps.cpp
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/ControlFlow/IR
DEPENDS
MLIRControlFlowOpsIncGen
LINK_LIBS PUBLIC
MLIRArithmetic
MLIRControlFlowInterfaces
MLIRIR
MLIRSideEffectInterfaces
)

View File

@ -0,0 +1,891 @@
//===- ControlFlowOps.cpp - ControlFlow Operations ------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/CommonFolders.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/IR/Value.h"
#include "mlir/Support/MathExtras.h"
#include "mlir/Transforms/InliningUtils.h"
#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/StringSwitch.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/raw_ostream.h"
#include <numeric>
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOpsDialect.cpp.inc"
using namespace mlir;
using namespace mlir::cf;
//===----------------------------------------------------------------------===//
// ControlFlowDialect Interfaces
//===----------------------------------------------------------------------===//
namespace {
/// This class defines the interface for handling inlining with control flow
/// operations. It unconditionally allows inlining of cf operations into any
/// region, since these ops carry no dialect-specific restrictions.
struct ControlFlowInlinerInterface : public DialectInlinerInterface {
  using DialectInlinerInterface::DialectInlinerInterface;
  ~ControlFlowInlinerInterface() override = default;

  /// All control flow operations can be inlined.
  bool isLegalToInline(Operation *call, Operation *callable,
                       bool wouldBeCloned) const final {
    return true;
  }
  /// All cf operations are legal to inline into any region.
  bool isLegalToInline(Operation *, Region *, bool,
                       BlockAndValueMapping &) const final {
    return true;
  }

  /// ControlFlow terminator operations don't really need any special handing.
  void handleTerminator(Operation *op, Block *newDest) const final {}
};
} // namespace
//===----------------------------------------------------------------------===//
// ControlFlowDialect
//===----------------------------------------------------------------------===//
/// Register all ODS-generated ControlFlow operations and attach the inliner
/// interface so cf ops participate in inlining.
void ControlFlowDialect::initialize() {
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.cpp.inc"
      >();
  addInterfaces<ControlFlowInlinerInterface>();
}
//===----------------------------------------------------------------------===//
// AssertOp
//===----------------------------------------------------------------------===//
/// Canonicalization: an assertion whose condition is the constant `true` can
/// never fire, so it is simply erased.
LogicalResult AssertOp::canonicalize(AssertOp op, PatternRewriter &rewriter) {
  // Only a statically-true condition makes the assert removable.
  if (!matchPattern(op.getArg(), m_One()))
    return failure();
  rewriter.eraseOp(op);
  return success();
}
//===----------------------------------------------------------------------===//
// BranchOp
//===----------------------------------------------------------------------===//
/// Given a successor, try to collapse it to a new destination if it only
/// contains a passthrough unconditional branch. If the successor is
/// collapsable, `successor` and `successorOperands` are updated to reference
/// the new destination and values. `argStorage` is used as storage if operands
/// to the collapsed successor need to be remapped. It must outlive uses of
/// successorOperands.
static LogicalResult collapseBranch(Block *&successor,
                                    ValueRange &successorOperands,
                                    SmallVectorImpl<Value> &argStorage) {
  // Check that the successor only contains an unconditional branch, i.e. the
  // terminator is the block's sole operation.
  if (std::next(successor->begin()) != successor->end())
    return failure();
  // Check that the terminator is an unconditional branch.
  BranchOp successorBranch = dyn_cast<BranchOp>(successor->getTerminator());
  if (!successorBranch)
    return failure();
  // Check that the arguments are only used within the terminator; any other
  // use would dangle after the block is bypassed.
  for (BlockArgument arg : successor->getArguments()) {
    for (Operation *user : arg.getUsers())
      if (user != successorBranch)
        return failure();
  }
  // Don't try to collapse branches to infinite loops (a block branching to
  // itself).
  Block *successorDest = successorBranch.getDest();
  if (successorDest == successor)
    return failure();

  // Update the operands to the successor. If the branch parent has no
  // arguments, we can use the branch operands directly.
  OperandRange operands = successorBranch.getOperands();
  if (successor->args_empty()) {
    successor = successorDest;
    successorOperands = operands;
    return success();
  }

  // Otherwise, we need to remap any argument operands: operands of the inner
  // branch that are block arguments of `successor` are substituted with the
  // values the caller was forwarding into that block.
  for (Value operand : operands) {
    BlockArgument argOperand = operand.dyn_cast<BlockArgument>();
    if (argOperand && argOperand.getOwner() == successor)
      argStorage.push_back(successorOperands[argOperand.getArgNumber()]);
    else
      argStorage.push_back(operand);
  }
  successor = successorDest;
  // NOTE: successorOperands now aliases argStorage, which the caller must keep
  // alive.
  successorOperands = argStorage;
  return success();
}
/// Fold a branch whose destination has no other predecessors by merging the
/// destination block into the branch's own block.
static LogicalResult
simplifyBrToBlockWithSinglePred(BranchOp op, PatternRewriter &rewriter) {
  Block *dest = op.getDest();
  Block *parent = op->getBlock();

  // Bail out on self-loops and on destinations reachable from elsewhere.
  if (dest == parent || !llvm::hasSingleElement(dest->getPredecessors()))
    return failure();

  // Splice the destination into the predecessor and drop the branch.
  rewriter.mergeBlocks(dest, parent, op.getOperands());
  rewriter.eraseOp(op);
  return success();
}
///   br ^bb1
/// ^bb1
///   br ^bbN(...)
///
///  -> br ^bbN(...)
///
static LogicalResult simplifyPassThroughBr(BranchOp op,
                                           PatternRewriter &rewriter) {
  Block *successor = op.getDest();
  ValueRange successorOperands = op.getOperands();
  // Backing storage for remapped operands; must outlive successorOperands.
  SmallVector<Value, 4> operandStorage;

  // A branch back into its own block can never be collapsed.
  if (successor == op->getBlock())
    return failure();

  // Ask the destination to collapse itself; on failure leave the op alone.
  if (failed(collapseBranch(successor, successorOperands, operandStorage)))
    return failure();

  // Re-target the branch at the collapsed destination.
  rewriter.replaceOpWithNewOp<BranchOp>(op, successor, successorOperands);
  return success();
}
/// Canonicalize `cf.br` by trying each simplification in turn; the first one
/// that fires wins.
LogicalResult BranchOp::canonicalize(BranchOp op, PatternRewriter &rewriter) {
  if (succeeded(simplifyBrToBlockWithSinglePred(op, rewriter)))
    return success();
  return simplifyPassThroughBr(op, rewriter);
}
/// Re-point the branch at `block`; forwards to the generic successor setter.
void BranchOp::setDest(Block *block) { return setSuccessor(block); }
/// Remove the forwarded operand at `index` from the branch.
void BranchOp::eraseOperand(unsigned index) { (*this)->eraseOperand(index); }
/// Return the (mutable) operands forwarded to the branch's single successor.
Optional<MutableOperandRange>
BranchOp::getMutableSuccessorOperands(unsigned index) {
  assert(index == 0 && "invalid successor index");
  return getDestOperandsMutable();
}
/// An unconditional branch always transfers control to its one destination,
/// regardless of (constant) operand values.
Block *BranchOp::getSuccessorForOperands(ArrayRef<Attribute>) {
  return getDest();
}
//===----------------------------------------------------------------------===//
// CondBranchOp
//===----------------------------------------------------------------------===//
namespace {
/// Fold a conditional branch on a compile-time-constant condition into an
/// unconditional branch to the statically-taken side:
///   cf.cond_br true,  ^bb1, ^bb2  ->  br ^bb1
///   cf.cond_br false, ^bb1, ^bb2  ->  br ^bb2
struct SimplifyConstCondBranchPred : public OpRewritePattern<CondBranchOp> {
  using OpRewritePattern<CondBranchOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(CondBranchOp op,
                                PatternRewriter &rewriter) const override {
    // Constant true: only the true side is ever taken.
    if (matchPattern(op.getCondition(), m_NonZero())) {
      rewriter.replaceOpWithNewOp<BranchOp>(op, op.getTrueDest(),
                                            op.getTrueOperands());
      return success();
    }
    // Constant false: only the false side is ever taken.
    if (matchPattern(op.getCondition(), m_Zero())) {
      rewriter.replaceOpWithNewOp<BranchOp>(op, op.getFalseDest(),
                                            op.getFalseOperands());
      return success();
    }
    // Non-constant condition: nothing to fold.
    return failure();
  }
};
///   cf.cond_br %cond, ^bb1, ^bb2
/// ^bb1
///   br ^bbN(...)
/// ^bb2
///   br ^bbK(...)
///
///  -> cf.cond_br %cond, ^bbN(...), ^bbK(...)
///
struct SimplifyPassThroughCondBranch : public OpRewritePattern<CondBranchOp> {
  using OpRewritePattern<CondBranchOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(CondBranchOp condbr,
                                PatternRewriter &rewriter) const override {
    Block *trueDest = condbr.getTrueDest(), *falseDest = condbr.getFalseDest();
    ValueRange trueDestOperands = condbr.getTrueOperands();
    ValueRange falseDestOperands = condbr.getFalseOperands();
    // Storage must outlive the ValueRanges above — collapseBranch may alias
    // them to this storage.
    SmallVector<Value, 4> trueDestOperandStorage, falseDestOperandStorage;

    // Try to collapse one of the current successors.
    LogicalResult collapsedTrue =
        collapseBranch(trueDest, trueDestOperands, trueDestOperandStorage);
    LogicalResult collapsedFalse =
        collapseBranch(falseDest, falseDestOperands, falseDestOperandStorage);
    if (failed(collapsedTrue) && failed(collapsedFalse))
      return failure();

    // Create a new branch with the collapsed successors.
    rewriter.replaceOpWithNewOp<CondBranchOp>(condbr, condbr.getCondition(),
                                              trueDest, trueDestOperands,
                                              falseDest, falseDestOperands);
    return success();
  }
};
///   cf.cond_br %cond, ^bb1(A, ..., N), ^bb1(A, ..., N)
///  -> br ^bb1(A, ..., N)
///
///   cf.cond_br %cond, ^bb1(A), ^bb1(B)
///  -> %select = arith.select %cond, A, B
///     br ^bb1(%select)
///
struct SimplifyCondBranchIdenticalSuccessors
    : public OpRewritePattern<CondBranchOp> {
  using OpRewritePattern<CondBranchOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(CondBranchOp condbr,
                                PatternRewriter &rewriter) const override {
    // Check that the true and false destinations are the same and have the same
    // operands.
    Block *trueDest = condbr.getTrueDest();
    if (trueDest != condbr.getFalseDest())
      return failure();

    // If all of the operands match, no selects need to be generated.
    OperandRange trueOperands = condbr.getTrueOperands();
    OperandRange falseOperands = condbr.getFalseOperands();
    if (trueOperands == falseOperands) {
      rewriter.replaceOpWithNewOp<BranchOp>(condbr, trueDest, trueOperands);
      return success();
    }

    // Otherwise, if the current block is the only predecessor insert selects
    // for any mismatched branch operands. (With other predecessors, a select
    // result would be wrong for control flow arriving from elsewhere.)
    if (trueDest->getUniquePredecessor() != condbr->getBlock())
      return failure();

    // Generate a select for any operands that differ between the two.
    SmallVector<Value, 8> mergedOperands;
    mergedOperands.reserve(trueOperands.size());
    Value condition = condbr.getCondition();
    for (auto it : llvm::zip(trueOperands, falseOperands)) {
      if (std::get<0>(it) == std::get<1>(it))
        mergedOperands.push_back(std::get<0>(it));
      else
        mergedOperands.push_back(rewriter.create<arith::SelectOp>(
            condbr.getLoc(), condition, std::get<0>(it), std::get<1>(it)));
    }

    rewriter.replaceOpWithNewOp<BranchOp>(condbr, trueDest, mergedOperands);
    return success();
  }
};
/// ...
///   cf.cond_br %cond, ^bb1(...), ^bb2(...)
/// ...
/// ^bb1: // has single predecessor
/// ...
///   cf.cond_br %cond, ^bb3(...), ^bb4(...)
///
/// ->
///
/// ...
///   cf.cond_br %cond, ^bb1(...), ^bb2(...)
/// ...
/// ^bb1: // has single predecessor
/// ...
///   br ^bb3(...)
///
struct SimplifyCondBranchFromCondBranchOnSameCondition
    : public OpRewritePattern<CondBranchOp> {
  using OpRewritePattern<CondBranchOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(CondBranchOp op,
                                PatternRewriter &rewriter) const override {
    // The containing block must be reachable from exactly one predecessor.
    Block *block = op->getBlock();
    Block *pred = block->getSinglePredecessor();
    if (!pred)
      return failure();

    // That predecessor must itself end in a cond_br on the same condition.
    auto predBranch = dyn_cast<CondBranchOp>(pred->getTerminator());
    if (!predBranch || op.getCondition() != predBranch.getCondition())
      return failure();

    // Reaching this block fixes the condition's value, so fold to an
    // unconditional branch down the matching side.
    bool conditionIsTrue = block == predBranch.getTrueDest();
    Block *dest = conditionIsTrue ? op.getTrueDest() : op.getFalseDest();
    ValueRange destOperands = conditionIsTrue ? op.getTrueDestOperands()
                                              : op.getFalseDestOperands();
    rewriter.replaceOpWithNewOp<BranchOp>(op, dest, destOperands);
    return success();
  }
};
///   cf.cond_br %arg0, ^trueB, ^falseB
///
/// ^trueB:
///   "test.consumer1"(%arg0) : (i1) -> ()
///   ...
///
/// ^falseB:
///   "test.consumer2"(%arg0) : (i1) -> ()
///   ...
///
/// ->
///
///   cf.cond_br %arg0, ^trueB, ^falseB
/// ^trueB:
///   "test.consumer1"(%true) : (i1) -> ()
///   ...
///
/// ^falseB:
///   "test.consumer2"(%false) : (i1) -> ()
///   ...
struct CondBranchTruthPropagation : public OpRewritePattern<CondBranchOp> {
  using OpRewritePattern<CondBranchOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(CondBranchOp condbr,
                                PatternRewriter &rewriter) const override {
    // Check that we have a single distinct predecessor.
    bool replaced = false;
    Type ty = rewriter.getI1Type();

    // These variables serve to prevent creating duplicate constants
    // and hold constant true or false values.
    Value constantTrue = nullptr;
    Value constantFalse = nullptr;

    // TODO These checks can be expanded to encompass any use with only
    // either the true or false edge as a predecessor. For now, we fall
    // back to checking the single predecessor is given by the true/false
    // destination, thereby ensuring that only that edge can reach the
    // op.
    if (condbr.getTrueDest()->getSinglePredecessor()) {
      // Inside the sole true successor the condition is known to be true.
      for (OpOperand &use :
           llvm::make_early_inc_range(condbr.getCondition().getUses())) {
        if (use.getOwner()->getBlock() == condbr.getTrueDest()) {
          replaced = true;

          if (!constantTrue)
            constantTrue = rewriter.create<arith::ConstantOp>(
                condbr.getLoc(), ty, rewriter.getBoolAttr(true));

          rewriter.updateRootInPlace(use.getOwner(),
                                     [&] { use.set(constantTrue); });
        }
      }
    }
    if (condbr.getFalseDest()->getSinglePredecessor()) {
      // Inside the sole false successor the condition is known to be false.
      for (OpOperand &use :
           llvm::make_early_inc_range(condbr.getCondition().getUses())) {
        if (use.getOwner()->getBlock() == condbr.getFalseDest()) {
          replaced = true;

          if (!constantFalse)
            constantFalse = rewriter.create<arith::ConstantOp>(
                condbr.getLoc(), ty, rewriter.getBoolAttr(false));

          rewriter.updateRootInPlace(use.getOwner(),
                                     [&] { use.set(constantFalse); });
        }
      }
    }
    return success(replaced);
  }
};
} // namespace
/// Register all cond_br canonicalization patterns defined above.
void CondBranchOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                               MLIRContext *context) {
  results.add<SimplifyConstCondBranchPred, SimplifyPassThroughCondBranch,
              SimplifyCondBranchIdenticalSuccessors,
              SimplifyCondBranchFromCondBranchOnSameCondition,
              CondBranchTruthPropagation>(context);
}
/// Return the (mutable) operands forwarded to the requested successor:
/// the true destination's operands for `trueIndex`, otherwise the false
/// destination's.
Optional<MutableOperandRange>
CondBranchOp::getMutableSuccessorOperands(unsigned index) {
  assert(index < getNumSuccessors() && "invalid successor index");
  if (index == trueIndex)
    return getTrueDestOperandsMutable();
  return getFalseDestOperandsMutable();
}
/// Resolve the branch target when the condition is a constant attribute;
/// returns null when the condition is unknown.
Block *CondBranchOp::getSuccessorForOperands(ArrayRef<Attribute> operands) {
  auto condAttr = operands.front().dyn_cast_or_null<IntegerAttr>();
  if (!condAttr)
    return nullptr;
  return condAttr.getValue().isOneValue() ? getTrueDest() : getFalseDest();
}
//===----------------------------------------------------------------------===//
// SwitchOp
//===----------------------------------------------------------------------===//
/// Build a switch from an already-packed DenseIntElementsAttr of case values.
/// This overload only reorders the arguments for the generated builder.
void SwitchOp::build(OpBuilder &builder, OperationState &result, Value value,
                     Block *defaultDestination, ValueRange defaultOperands,
                     DenseIntElementsAttr caseValues,
                     BlockRange caseDestinations,
                     ArrayRef<ValueRange> caseOperands) {
  build(builder, result, value, defaultOperands, caseOperands, caseValues,
        defaultDestination, caseDestinations);
}
/// Build a switch from raw APInt case values. The values are packed into a
/// dense vector attribute typed after the flag; an empty list yields a null
/// attribute (the "default only" form).
void SwitchOp::build(OpBuilder &builder, OperationState &result, Value value,
                     Block *defaultDestination, ValueRange defaultOperands,
                     ArrayRef<APInt> caseValues, BlockRange caseDestinations,
                     ArrayRef<ValueRange> caseOperands) {
  DenseIntElementsAttr valuesAttr;
  if (!caseValues.empty()) {
    auto vectorTy = VectorType::get(static_cast<int64_t>(caseValues.size()),
                                    value.getType());
    valuesAttr = DenseIntElementsAttr::get(vectorTy, caseValues);
  }
  build(builder, result, value, defaultDestination, defaultOperands, valuesAttr,
        caseDestinations, caseOperands);
}
/// <cases> ::= `default` `:` bb-id (`(` ssa-use-and-type-list `)`)?
///             ( `,` integer `:` bb-id (`(` ssa-use-and-type-list `)`)? )*
static ParseResult parseSwitchOpCases(
    OpAsmParser &parser, Type &flagType, Block *&defaultDestination,
    SmallVectorImpl<OpAsmParser::OperandType> &defaultOperands,
    SmallVectorImpl<Type> &defaultOperandTypes,
    DenseIntElementsAttr &caseValues,
    SmallVectorImpl<Block *> &caseDestinations,
    SmallVectorImpl<SmallVector<OpAsmParser::OperandType>> &caseOperands,
    SmallVectorImpl<SmallVector<Type>> &caseOperandTypes) {
  // The default destination is mandatory and always listed first.
  if (parser.parseKeyword("default") || parser.parseColon() ||
      parser.parseSuccessor(defaultDestination))
    return failure();
  // Optional parenthesized operand list forwarded to the default block.
  if (succeeded(parser.parseOptionalLParen())) {
    if (parser.parseRegionArgumentList(defaultOperands) ||
        parser.parseColonTypeList(defaultOperandTypes) || parser.parseRParen())
      return failure();
  }

  // Parse `, value: ^dest(...)` entries until no comma follows.
  SmallVector<APInt> values;
  unsigned bitWidth = flagType.getIntOrFloatBitWidth();
  while (succeeded(parser.parseOptionalComma())) {
    int64_t value = 0;
    if (failed(parser.parseInteger(value)))
      return failure();
    // Case values are stored at the bit width of the flag type.
    values.push_back(APInt(bitWidth, value));
    Block *destination;
    SmallVector<OpAsmParser::OperandType> operands;
    SmallVector<Type> operandTypes;
    if (failed(parser.parseColon()) ||
        failed(parser.parseSuccessor(destination)))
      return failure();
    // Optional parenthesized operand list forwarded to this case block.
    if (succeeded(parser.parseOptionalLParen())) {
      if (failed(parser.parseRegionArgumentList(operands)) ||
          failed(parser.parseColonTypeList(operandTypes)) ||
          failed(parser.parseRParen()))
        return failure();
    }
    caseDestinations.push_back(destination);
    caseOperands.emplace_back(operands);
    caseOperandTypes.emplace_back(operandTypes);
  }

  // Package the collected case values into a dense vector attribute; the
  // attribute stays null when there are no cases.
  if (!values.empty()) {
    ShapedType caseValueType =
        VectorType::get(static_cast<int64_t>(values.size()), flagType);
    caseValues = DenseIntElementsAttr::get(caseValueType, values);
  }
  return success();
}
/// Print the switch case list in the form parseSwitchOpCases accepts: the
/// default destination first, then one `value: ^dest(...)` entry per case,
/// each on its own line.
static void printSwitchOpCases(
    OpAsmPrinter &p, SwitchOp op, Type flagType, Block *defaultDestination,
    OperandRange defaultOperands, TypeRange defaultOperandTypes,
    DenseIntElementsAttr caseValues, SuccessorRange caseDestinations,
    OperandRangeRange caseOperands, const TypeRangeRange &caseOperandTypes) {
  p << " default: ";
  p.printSuccessorAndUseList(defaultDestination, defaultOperands);

  // A null attribute means the "default only" form: nothing more to print.
  if (!caseValues)
    return;

  for (const auto &it : llvm::enumerate(caseValues.getValues<APInt>())) {
    p << ',';
    p.printNewline();
    p << " ";
    p << it.value().getLimitedValue();
    p << ": ";
    p.printSuccessorAndUseList(caseDestinations[it.index()],
                               caseOperands[it.index()]);
  }
  p.printNewline();
}
/// Verify the structural invariants of a switch: the case value element type
/// must match the flag type, and the number of case values must match the
/// number of case destinations.
LogicalResult SwitchOp::verify() {
  auto caseValues = getCaseValues();
  auto caseDestinations = getCaseDestinations();

  // Degenerate form: a switch with only a default destination.
  if (!caseValues && caseDestinations.empty())
    return success();

  // Diagnose case destinations without case values up front; the previous
  // code dereferenced the empty Optional below, crashing the verifier.
  if (!caseValues)
    return emitOpError() << "number of case values (0) should match number of "
                            "case destinations ("
                         << caseDestinations.size() << ")";

  Type flagType = getFlag().getType();
  Type caseValueType = caseValues->getType().getElementType();
  if (caseValueType != flagType)
    return emitOpError() << "'flag' type (" << flagType
                         << ") should match case value type (" << caseValueType
                         << ")";
  // caseValues is known non-null here.
  if (caseValues->size() != static_cast<int64_t>(caseDestinations.size()))
    return emitOpError() << "number of case values (" << caseValues->size()
                         << ") should match number of "
                            "case destinations ("
                         << caseDestinations.size() << ")";
  return success();
}
/// Return the operands forwarded to the successor at `index`. Successor 0 is
/// the default destination; successors 1..N correspond to the cases.
Optional<MutableOperandRange>
SwitchOp::getMutableSuccessorOperands(unsigned index) {
  assert(index < getNumSuccessors() && "invalid successor index");
  if (index == 0)
    return getDefaultOperandsMutable();
  return getCaseOperandsMutable(index - 1);
}
/// If the flag operand folds to a constant integer, return the destination
/// that will be taken; otherwise return nullptr (successor unknown).
Block *SwitchOp::getSuccessorForOperands(ArrayRef<Attribute> operands) {
  Optional<DenseIntElementsAttr> caseValues = getCaseValues();
  if (!caseValues)
    return getDefaultDestination();

  auto flagAttr = operands.front().dyn_cast_or_null<IntegerAttr>();
  if (!flagAttr)
    return nullptr;

  // Scan the case values for a match; fall back to the default destination.
  SuccessorRange caseDests = getCaseDestinations();
  unsigned destIndex = 0;
  for (const APInt &candidate : caseValues->getValues<APInt>()) {
    if (candidate == flagAttr.getValue())
      return caseDests[destIndex];
    ++destIndex;
  }
  return getDefaultDestination();
}
/// switch %flag : i32, [
///   default: ^bb1
/// ]
///  -> br ^bb1
static LogicalResult simplifySwitchWithOnlyDefault(SwitchOp op,
                                                   PatternRewriter &rewriter) {
  // Only applies when the default successor is the sole destination.
  if (!op.getCaseDestinations().empty())
    return failure();

  rewriter.replaceOpWithNewOp<BranchOp>(op, op.getDefaultDestination(),
                                        op.getDefaultOperands());
  return success();
}
/// switch %flag : i32, [
///   default: ^bb1,
///   42: ^bb1,
///   43: ^bb2
/// ]
/// ->
/// switch %flag : i32, [
///   default: ^bb1,
///   43: ^bb2
/// ]
static LogicalResult
dropSwitchCasesThatMatchDefault(SwitchOp op, PatternRewriter &rewriter) {
  Block *defaultDest = op.getDefaultDestination();
  auto caseValues = op.getCaseValues();
  auto caseDests = op.getCaseDestinations();

  // Keep only the cases that differ from the default in either destination
  // or forwarded operands; the rest are redundant.
  SmallVector<Block *> keptDestinations;
  SmallVector<ValueRange> keptOperands;
  SmallVector<APInt> keptValues;
  bool droppedAnyCase = false;
  unsigned index = 0;
  for (const APInt &value : caseValues->getValues<APInt>()) {
    if (caseDests[index] == defaultDest &&
        op.getCaseOperands(index) == op.getDefaultOperands()) {
      droppedAnyCase = true;
    } else {
      keptDestinations.push_back(caseDests[index]);
      keptOperands.push_back(op.getCaseOperands(index));
      keptValues.push_back(value);
    }
    ++index;
  }
  if (!droppedAnyCase)
    return failure();

  rewriter.replaceOpWithNewOp<SwitchOp>(op, op.getFlag(), defaultDest,
                                        op.getDefaultOperands(), keptValues,
                                        keptDestinations, keptOperands);
  return success();
}
/// Helper for folding a switch with a constant value.
/// switch %c_42 : i32, [
///   default: ^bb1 ,
///   42: ^bb2,
///   43: ^bb3
/// ]
/// -> br ^bb2
static void foldSwitch(SwitchOp op, PatternRewriter &rewriter,
                       const APInt &caseValue) {
  auto caseValues = op.getCaseValues();
  // Branch directly to the first case whose value matches the constant flag.
  unsigned index = 0;
  for (const APInt &candidate : caseValues->getValues<APInt>()) {
    if (candidate == caseValue) {
      rewriter.replaceOpWithNewOp<BranchOp>(op,
                                            op.getCaseDestinations()[index],
                                            op.getCaseOperands(index));
      return;
    }
    ++index;
  }
  // No case matched: the default destination is taken.
  rewriter.replaceOpWithNewOp<BranchOp>(op, op.getDefaultDestination(),
                                        op.getDefaultOperands());
}
/// switch %c_42 : i32, [
///   default: ^bb1,
///   42: ^bb2,
///   43: ^bb3
/// ]
/// -> br ^bb2
static LogicalResult simplifyConstSwitchValue(SwitchOp op,
                                              PatternRewriter &rewriter) {
  // Bail out unless the flag is produced by an integer constant.
  APInt flagValue;
  if (!matchPattern(op.getFlag(), m_ConstantInt(&flagValue)))
    return failure();

  foldSwitch(op, rewriter, flagValue);
  return success();
}
/// switch %c_42 : i32, [
///   default: ^bb1,
///   42: ^bb2,
/// ]
/// ^bb2:
///   br ^bb3
/// ->
/// switch %c_42 : i32, [
///   default: ^bb1,
///   42: ^bb3,
/// ]
static LogicalResult simplifyPassThroughSwitch(SwitchOp op,
                                               PatternRewriter &rewriter) {
  SmallVector<Block *> newCaseDests;
  SmallVector<ValueRange> newCaseOperands;
  // Backing storage for operands remapped by collapseBranch; the ValueRanges
  // pushed into `newCaseOperands` may reference it, so it must stay alive
  // until the op is rebuilt below.
  SmallVector<SmallVector<Value>> argStorage;
  auto caseValues = op.getCaseValues();
  auto caseDests = op.getCaseDestinations();
  bool requiresChange = false;
  // Try to collapse each case successor that merely forwards to another
  // block; collapseBranch updates `caseDest`/`caseOperands` in place on
  // success.
  for (int64_t i = 0, size = caseValues->size(); i < size; ++i) {
    Block *caseDest = caseDests[i];
    ValueRange caseOperands = op.getCaseOperands(i);
    argStorage.emplace_back();
    if (succeeded(collapseBranch(caseDest, caseOperands, argStorage.back())))
      requiresChange = true;
    newCaseDests.push_back(caseDest);
    newCaseOperands.push_back(caseOperands);
  }
  // The default successor may be collapsible as well.
  Block *defaultDest = op.getDefaultDestination();
  ValueRange defaultOperands = op.getDefaultOperands();
  argStorage.emplace_back();
  if (succeeded(
          collapseBranch(defaultDest, defaultOperands, argStorage.back())))
    requiresChange = true;
  if (!requiresChange)
    return failure();
  rewriter.replaceOpWithNewOp<SwitchOp>(op, op.getFlag(), defaultDest,
                                        defaultOperands, caseValues.getValue(),
                                        newCaseDests, newCaseOperands);
  return success();
}
/// switch %flag : i32, [
///   default: ^bb1,
///   42: ^bb2,
/// ]
/// ^bb2:
///   switch %flag : i32, [
///     default: ^bb3,
///     42: ^bb4
///   ]
/// ->
/// switch %flag : i32, [
///   default: ^bb1,
///   42: ^bb2,
/// ]
/// ^bb2:
///   br ^bb4
///
///  and
///
/// switch %flag : i32, [
///   default: ^bb1,
///   42: ^bb2,
/// ]
/// ^bb2:
///   switch %flag : i32, [
///     default: ^bb3,
///     43: ^bb4
///   ]
/// ->
/// switch %flag : i32, [
///   default: ^bb1,
///   42: ^bb2,
/// ]
/// ^bb2:
///   br ^bb3
static LogicalResult
simplifySwitchFromSwitchOnSameCondition(SwitchOp op,
                                        PatternRewriter &rewriter) {
  // Check that we have a single distinct predecessor.
  Block *currentBlock = op->getBlock();
  Block *predecessor = currentBlock->getSinglePredecessor();
  if (!predecessor)
    return failure();

  // Check that the predecessor terminates with a switch branch to this block
  // and that it branches on the same condition and that this branch isn't the
  // default destination.
  auto predSwitch = dyn_cast<SwitchOp>(predecessor->getTerminator());
  if (!predSwitch || op.getFlag() != predSwitch.getFlag() ||
      predSwitch.getDefaultDestination() == currentBlock)
    return failure();

  // Fold this switch to an unconditional branch.
  SuccessorRange predDests = predSwitch.getCaseDestinations();
  auto it = llvm::find(predDests, currentBlock);
  if (it != predDests.end()) {
    // The predecessor reaches us through case i, so the flag is known to be
    // the i-th case value; fold accordingly. The iterator difference indexes
    // into the parallel case-value attribute.
    Optional<DenseIntElementsAttr> predCaseValues = predSwitch.getCaseValues();
    foldSwitch(op, rewriter,
               predCaseValues->getValues<APInt>()[it - predDests.begin()]);
  } else {
    // Not found among the predecessor's cases: unreachable path in practice
    // (the default destination was excluded above), fold to the default.
    rewriter.replaceOpWithNewOp<BranchOp>(op, op.getDefaultDestination(),
                                          op.getDefaultOperands());
  }
  return success();
}
/// switch %flag : i32, [
///   default: ^bb1,
///   42: ^bb2
/// ]
/// ^bb1:
///   switch %flag : i32, [
///     default: ^bb3,
///     42: ^bb4,
///     43: ^bb5
///   ]
/// ->
/// switch %flag : i32, [
///   default: ^bb1,
///   42: ^bb2,
/// ]
/// ^bb1:
///   switch %flag : i32, [
///     default: ^bb3,
///     43: ^bb5
///   ]
static LogicalResult
simplifySwitchFromDefaultSwitchOnSameCondition(SwitchOp op,
                                               PatternRewriter &rewriter) {
  // This block must be reached exclusively through one predecessor.
  Block *currentBlock = op->getBlock();
  Block *predecessor = currentBlock->getSinglePredecessor();
  if (!predecessor)
    return failure();

  // That predecessor must itself switch on the same flag, and must reach this
  // block through its default destination.
  auto predSwitch = dyn_cast<SwitchOp>(predecessor->getTerminator());
  if (!predSwitch || op.getFlag() != predSwitch.getFlag() ||
      predSwitch.getDefaultDestination() != currentBlock)
    return failure();

  // Any value handled by a non-default case of the predecessor cannot reach
  // this block; collect those values for removal.
  DenseSet<APInt> impossibleValues;
  auto predDests = predSwitch.getCaseDestinations();
  auto predCaseValues = predSwitch.getCaseValues();
  for (int64_t i = 0, size = predCaseValues->size(); i < size; ++i)
    if (currentBlock != predDests[i])
      impossibleValues.insert(predCaseValues->getValues<APInt>()[i]);

  // Rebuild this switch's case list without the impossible values.
  SmallVector<Block *> keptDestinations;
  SmallVector<ValueRange> keptOperands;
  SmallVector<APInt> keptValues;
  bool droppedAnyCase = false;
  unsigned index = 0;
  auto caseValues = op.getCaseValues();
  auto caseDests = op.getCaseDestinations();
  for (const APInt &value : caseValues->getValues<APInt>()) {
    if (impossibleValues.contains(value)) {
      droppedAnyCase = true;
    } else {
      keptDestinations.push_back(caseDests[index]);
      keptOperands.push_back(op.getCaseOperands(index));
      keptValues.push_back(value);
    }
    ++index;
  }
  if (!droppedAnyCase)
    return failure();

  rewriter.replaceOpWithNewOp<SwitchOp>(
      op, op.getFlag(), op.getDefaultDestination(), op.getDefaultOperands(),
      keptValues, keptDestinations, keptOperands);
  return success();
}
/// Register the switch canonicalization patterns defined above.
void SwitchOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                           MLIRContext *context) {
  results.add(&simplifySwitchWithOnlyDefault);
  results.add(&dropSwitchCasesThatMatchDefault);
  results.add(&simplifyConstSwitchValue);
  results.add(&simplifyPassThroughSwitch);
  results.add(&simplifySwitchFromSwitchOnSameCondition);
  results.add(&simplifySwitchFromDefaultSwitchOnSameCondition);
}
//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//
#define GET_OP_CLASSES
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.cpp.inc"

View File

@ -12,10 +12,10 @@
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/GPU/GPUDialect.h" #include "mlir/Dialect/GPU/GPUDialect.h"
#include "mlir/Dialect/GPU/Passes.h" #include "mlir/Dialect/GPU/Passes.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/BlockAndValueMapping.h" #include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/Builders.h" #include "mlir/IR/Builders.h"
#include "mlir/IR/PatternMatch.h" #include "mlir/IR/PatternMatch.h"
@ -44,14 +44,14 @@ struct GpuAllReduceRewriter {
/// workgroup memory. /// workgroup memory.
/// ///
/// %subgroup_reduce = `createSubgroupReduce(%operand)` /// %subgroup_reduce = `createSubgroupReduce(%operand)`
/// cond_br %is_first_lane, ^then1, ^continue1 /// cf.cond_br %is_first_lane, ^then1, ^continue1
/// ^then1: /// ^then1:
/// store %subgroup_reduce, %workgroup_buffer[%subgroup_id] /// store %subgroup_reduce, %workgroup_buffer[%subgroup_id]
/// br ^continue1 /// cf.br ^continue1
/// ^continue1: /// ^continue1:
/// gpu.barrier /// gpu.barrier
/// %is_valid_subgroup = arith.cmpi "slt" %invocation_idx, %num_subgroups /// %is_valid_subgroup = arith.cmpi "slt" %invocation_idx, %num_subgroups
/// cond_br %is_valid_subgroup, ^then2, ^continue2 /// cf.cond_br %is_valid_subgroup, ^then2, ^continue2
/// ^then2: /// ^then2:
/// %partial_reduce = load %workgroup_buffer[%invocation_idx] /// %partial_reduce = load %workgroup_buffer[%invocation_idx]
/// %all_reduce = `createSubgroupReduce(%partial_reduce)` /// %all_reduce = `createSubgroupReduce(%partial_reduce)`
@ -194,7 +194,7 @@ private:
// Add branch before inserted body, into body. // Add branch before inserted body, into body.
block = block->getNextNode(); block = block->getNextNode();
create<BranchOp>(block, ValueRange()); create<cf::BranchOp>(block, ValueRange());
// Replace all gpu.yield ops with branch out of body. // Replace all gpu.yield ops with branch out of body.
for (; block != split; block = block->getNextNode()) { for (; block != split; block = block->getNextNode()) {
@ -202,7 +202,7 @@ private:
if (!isa<gpu::YieldOp>(terminator)) if (!isa<gpu::YieldOp>(terminator))
continue; continue;
rewriter.setInsertionPointToEnd(block); rewriter.setInsertionPointToEnd(block);
rewriter.replaceOpWithNewOp<BranchOp>( rewriter.replaceOpWithNewOp<cf::BranchOp>(
terminator, split, ValueRange(terminator->getOperand(0))); terminator, split, ValueRange(terminator->getOperand(0)));
} }
@ -285,17 +285,17 @@ private:
Block *continueBlock = rewriter.splitBlock(elseBlock, elseBlock->begin()); Block *continueBlock = rewriter.splitBlock(elseBlock, elseBlock->begin());
rewriter.setInsertionPointToEnd(currentBlock); rewriter.setInsertionPointToEnd(currentBlock);
create<CondBranchOp>(condition, thenBlock, create<cf::CondBranchOp>(condition, thenBlock,
/*trueOperands=*/ArrayRef<Value>(), elseBlock, /*trueOperands=*/ArrayRef<Value>(), elseBlock,
/*falseOperands=*/ArrayRef<Value>()); /*falseOperands=*/ArrayRef<Value>());
rewriter.setInsertionPointToStart(thenBlock); rewriter.setInsertionPointToStart(thenBlock);
auto thenOperands = thenOpsFactory(); auto thenOperands = thenOpsFactory();
create<BranchOp>(continueBlock, thenOperands); create<cf::BranchOp>(continueBlock, thenOperands);
rewriter.setInsertionPointToStart(elseBlock); rewriter.setInsertionPointToStart(elseBlock);
auto elseOperands = elseOpsFactory(); auto elseOperands = elseOpsFactory();
create<BranchOp>(continueBlock, elseOperands); create<cf::BranchOp>(continueBlock, elseOperands);
assert(thenOperands.size() == elseOperands.size()); assert(thenOperands.size() == elseOperands.size());
rewriter.setInsertionPointToStart(continueBlock); rewriter.setInsertionPointToStart(continueBlock);

View File

@ -12,6 +12,7 @@
#include "PassDetail.h" #include "PassDetail.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/DLTI/DLTI.h" #include "mlir/Dialect/DLTI/DLTI.h"
#include "mlir/Dialect/GPU/GPUDialect.h" #include "mlir/Dialect/GPU/GPUDialect.h"
#include "mlir/Dialect/GPU/Passes.h" #include "mlir/Dialect/GPU/Passes.h"
@ -186,7 +187,7 @@ static gpu::GPUFuncOp outlineKernelFuncImpl(gpu::LaunchOp launchOp,
Block &launchOpEntry = launchOpBody.front(); Block &launchOpEntry = launchOpBody.front();
Block *clonedLaunchOpEntry = map.lookup(&launchOpEntry); Block *clonedLaunchOpEntry = map.lookup(&launchOpEntry);
builder.setInsertionPointToEnd(&entryBlock); builder.setInsertionPointToEnd(&entryBlock);
builder.create<BranchOp>(loc, clonedLaunchOpEntry); builder.create<cf::BranchOp>(loc, clonedLaunchOpEntry);
outlinedFunc.walk([](gpu::TerminatorOp op) { outlinedFunc.walk([](gpu::TerminatorOp op) {
OpBuilder replacer(op); OpBuilder replacer(op);

View File

@ -7,6 +7,7 @@
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
#include "PassDetail.h" #include "PassDetail.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Passes.h" #include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/StandardOps/Transforms/FuncConversions.h" #include "mlir/Dialect/StandardOps/Transforms/FuncConversions.h"
@ -254,13 +255,13 @@ struct LinalgDetensorize : public LinalgDetensorizeBase<LinalgDetensorize> {
DenseSet<BlockArgument> &blockArgsToDetensor) override { DenseSet<BlockArgument> &blockArgsToDetensor) override {
SmallVector<Value> workList; SmallVector<Value> workList;
func->walk([&](CondBranchOp condBr) { func->walk([&](cf::CondBranchOp condBr) {
for (auto operand : condBr.getOperands()) { for (auto operand : condBr.getOperands()) {
workList.push_back(operand); workList.push_back(operand);
} }
}); });
func->walk([&](BranchOp br) { func->walk([&](cf::BranchOp br) {
for (auto operand : br.getOperands()) { for (auto operand : br.getOperands()) {
workList.push_back(operand); workList.push_back(operand);
} }

View File

@ -9,6 +9,7 @@
#include "mlir/Dialect/SCF/SCF.h" #include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h" #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/BlockAndValueMapping.h" #include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/Matchers.h" #include "mlir/IR/Matchers.h"
@ -165,13 +166,13 @@ struct SingleBlockExecuteInliner : public OpRewritePattern<ExecuteRegionOp> {
// "test.foo"() : () -> () // "test.foo"() : () -> ()
// %v = scf.execute_region -> i64 { // %v = scf.execute_region -> i64 {
// %c = "test.cmp"() : () -> i1 // %c = "test.cmp"() : () -> i1
// cond_br %c, ^bb2, ^bb3 // cf.cond_br %c, ^bb2, ^bb3
// ^bb2: // ^bb2:
// %x = "test.val1"() : () -> i64 // %x = "test.val1"() : () -> i64
// br ^bb4(%x : i64) // cf.br ^bb4(%x : i64)
// ^bb3: // ^bb3:
// %y = "test.val2"() : () -> i64 // %y = "test.val2"() : () -> i64
// br ^bb4(%y : i64) // cf.br ^bb4(%y : i64)
// ^bb4(%z : i64): // ^bb4(%z : i64):
// scf.yield %z : i64 // scf.yield %z : i64
// } // }
@ -184,13 +185,13 @@ struct SingleBlockExecuteInliner : public OpRewritePattern<ExecuteRegionOp> {
// func @func_execute_region_elim() { // func @func_execute_region_elim() {
// "test.foo"() : () -> () // "test.foo"() : () -> ()
// %c = "test.cmp"() : () -> i1 // %c = "test.cmp"() : () -> i1
// cond_br %c, ^bb1, ^bb2 // cf.cond_br %c, ^bb1, ^bb2
// ^bb1: // pred: ^bb0 // ^bb1: // pred: ^bb0
// %x = "test.val1"() : () -> i64 // %x = "test.val1"() : () -> i64
// br ^bb3(%x : i64) // cf.br ^bb3(%x : i64)
// ^bb2: // pred: ^bb0 // ^bb2: // pred: ^bb0
// %y = "test.val2"() : () -> i64 // %y = "test.val2"() : () -> i64
// br ^bb3(%y : i64) // cf.br ^bb3(%y : i64)
// ^bb3(%z: i64): // 2 preds: ^bb1, ^bb2 // ^bb3(%z: i64): // 2 preds: ^bb1, ^bb2
// "test.bar"(%z) : (i64) -> () // "test.bar"(%z) : (i64) -> ()
// return // return
@ -208,13 +209,13 @@ struct MultiBlockExecuteInliner : public OpRewritePattern<ExecuteRegionOp> {
Block *postBlock = rewriter.splitBlock(prevBlock, op->getIterator()); Block *postBlock = rewriter.splitBlock(prevBlock, op->getIterator());
rewriter.setInsertionPointToEnd(prevBlock); rewriter.setInsertionPointToEnd(prevBlock);
rewriter.create<BranchOp>(op.getLoc(), &op.getRegion().front()); rewriter.create<cf::BranchOp>(op.getLoc(), &op.getRegion().front());
for (Block &blk : op.getRegion()) { for (Block &blk : op.getRegion()) {
if (YieldOp yieldOp = dyn_cast<YieldOp>(blk.getTerminator())) { if (YieldOp yieldOp = dyn_cast<YieldOp>(blk.getTerminator())) {
rewriter.setInsertionPoint(yieldOp); rewriter.setInsertionPoint(yieldOp);
rewriter.create<BranchOp>(yieldOp.getLoc(), postBlock, rewriter.create<cf::BranchOp>(yieldOp.getLoc(), postBlock,
yieldOp.getResults()); yieldOp.getResults());
rewriter.eraseOp(yieldOp); rewriter.eraseOp(yieldOp);
} }
} }

View File

@ -13,7 +13,7 @@ add_mlir_dialect_library(MLIRSparseTensorPipelines
MLIRMemRefToLLVM MLIRMemRefToLLVM
MLIRPass MLIRPass
MLIRReconcileUnrealizedCasts MLIRReconcileUnrealizedCasts
MLIRSCFToStandard MLIRSCFToControlFlow
MLIRSparseTensor MLIRSparseTensor
MLIRSparseTensorTransforms MLIRSparseTensorTransforms
MLIRStandardOpsTransforms MLIRStandardOpsTransforms

View File

@ -33,7 +33,7 @@ void mlir::sparse_tensor::buildSparseCompiler(
pm.addNestedPass<FuncOp>(createLinalgBufferizePass()); pm.addNestedPass<FuncOp>(createLinalgBufferizePass());
pm.addNestedPass<FuncOp>(createConvertLinalgToLoopsPass()); pm.addNestedPass<FuncOp>(createConvertLinalgToLoopsPass());
pm.addNestedPass<FuncOp>(createConvertVectorToSCFPass()); pm.addNestedPass<FuncOp>(createConvertVectorToSCFPass());
pm.addPass(createLowerToCFGPass()); // --convert-scf-to-std pm.addNestedPass<FuncOp>(createConvertSCFToCFPass());
pm.addPass(createFuncBufferizePass()); pm.addPass(createFuncBufferizePass());
pm.addPass(arith::createConstantBufferizePass()); pm.addPass(arith::createConstantBufferizePass());
pm.addNestedPass<FuncOp>(createTensorBufferizePass()); pm.addNestedPass<FuncOp>(createTensorBufferizePass());

View File

@ -11,6 +11,7 @@ add_mlir_dialect_library(MLIRStandard
MLIRArithmetic MLIRArithmetic
MLIRCallInterfaces MLIRCallInterfaces
MLIRCastInterfaces MLIRCastInterfaces
MLIRControlFlow
MLIRControlFlowInterfaces MLIRControlFlowInterfaces
MLIRInferTypeOpInterface MLIRInferTypeOpInterface
MLIRIR MLIRIR

View File

@ -8,9 +8,8 @@
#include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Arithmetic/Utils/Utils.h"
#include "mlir/Dialect/CommonFolders.h" #include "mlir/Dialect/CommonFolders.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/IR/AffineExpr.h" #include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h" #include "mlir/IR/AffineMap.h"
#include "mlir/IR/BlockAndValueMapping.h" #include "mlir/IR/BlockAndValueMapping.h"
@ -77,7 +76,7 @@ struct StdInlinerInterface : public DialectInlinerInterface {
// Replace the return with a branch to the dest. // Replace the return with a branch to the dest.
OpBuilder builder(op); OpBuilder builder(op);
builder.create<BranchOp>(op->getLoc(), newDest, returnOp.getOperands()); builder.create<cf::BranchOp>(op->getLoc(), newDest, returnOp.getOperands());
op->erase(); op->erase();
} }
@ -121,130 +120,6 @@ Operation *StandardOpsDialect::materializeConstant(OpBuilder &builder,
return nullptr; return nullptr;
} }
//===----------------------------------------------------------------------===//
// AssertOp
//===----------------------------------------------------------------------===//
LogicalResult AssertOp::canonicalize(AssertOp op, PatternRewriter &rewriter) {
// Erase assertion if argument is constant true.
if (matchPattern(op.getArg(), m_One())) {
rewriter.eraseOp(op);
return success();
}
return failure();
}
//===----------------------------------------------------------------------===//
// BranchOp
//===----------------------------------------------------------------------===//
/// Given a successor, try to collapse it to a new destination if it only
/// contains a passthrough unconditional branch. If the successor is
/// collapsable, `successor` and `successorOperands` are updated to reference
/// the new destination and values. `argStorage` is used as storage if operands
/// to the collapsed successor need to be remapped. It must outlive uses of
/// successorOperands.
static LogicalResult collapseBranch(Block *&successor,
ValueRange &successorOperands,
SmallVectorImpl<Value> &argStorage) {
// Check that the successor only contains a unconditional branch.
if (std::next(successor->begin()) != successor->end())
return failure();
// Check that the terminator is an unconditional branch.
BranchOp successorBranch = dyn_cast<BranchOp>(successor->getTerminator());
if (!successorBranch)
return failure();
// Check that the arguments are only used within the terminator.
for (BlockArgument arg : successor->getArguments()) {
for (Operation *user : arg.getUsers())
if (user != successorBranch)
return failure();
}
// Don't try to collapse branches to infinite loops.
Block *successorDest = successorBranch.getDest();
if (successorDest == successor)
return failure();
// Update the operands to the successor. If the branch parent has no
// arguments, we can use the branch operands directly.
OperandRange operands = successorBranch.getOperands();
if (successor->args_empty()) {
successor = successorDest;
successorOperands = operands;
return success();
}
// Otherwise, we need to remap any argument operands.
for (Value operand : operands) {
BlockArgument argOperand = operand.dyn_cast<BlockArgument>();
if (argOperand && argOperand.getOwner() == successor)
argStorage.push_back(successorOperands[argOperand.getArgNumber()]);
else
argStorage.push_back(operand);
}
successor = successorDest;
successorOperands = argStorage;
return success();
}
/// Simplify a branch to a block that has a single predecessor. This effectively
/// merges the two blocks.
static LogicalResult
simplifyBrToBlockWithSinglePred(BranchOp op, PatternRewriter &rewriter) {
// Check that the successor block has a single predecessor.
Block *succ = op.getDest();
Block *opParent = op->getBlock();
if (succ == opParent || !llvm::hasSingleElement(succ->getPredecessors()))
return failure();
// Merge the successor into the current block and erase the branch.
rewriter.mergeBlocks(succ, opParent, op.getOperands());
rewriter.eraseOp(op);
return success();
}
/// br ^bb1
/// ^bb1
/// br ^bbN(...)
///
/// -> br ^bbN(...)
///
static LogicalResult simplifyPassThroughBr(BranchOp op,
PatternRewriter &rewriter) {
Block *dest = op.getDest();
ValueRange destOperands = op.getOperands();
SmallVector<Value, 4> destOperandStorage;
// Try to collapse the successor if it points somewhere other than this
// block.
if (dest == op->getBlock() ||
failed(collapseBranch(dest, destOperands, destOperandStorage)))
return failure();
// Create a new branch with the collapsed successor.
rewriter.replaceOpWithNewOp<BranchOp>(op, dest, destOperands);
return success();
}
LogicalResult BranchOp::canonicalize(BranchOp op, PatternRewriter &rewriter) {
return success(succeeded(simplifyBrToBlockWithSinglePred(op, rewriter)) ||
succeeded(simplifyPassThroughBr(op, rewriter)));
}
void BranchOp::setDest(Block *block) { return setSuccessor(block); }
void BranchOp::eraseOperand(unsigned index) { (*this)->eraseOperand(index); }
Optional<MutableOperandRange>
BranchOp::getMutableSuccessorOperands(unsigned index) {
assert(index == 0 && "invalid successor index");
return getDestOperandsMutable();
}
Block *BranchOp::getSuccessorForOperands(ArrayRef<Attribute>) {
return getDest();
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// CallOp // CallOp
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -307,260 +182,6 @@ LogicalResult CallIndirectOp::canonicalize(CallIndirectOp indirectCall,
return success(); return success();
} }
//===----------------------------------------------------------------------===//
// CondBranchOp
//===----------------------------------------------------------------------===//
namespace {
/// cond_br true, ^bb1, ^bb2
/// -> br ^bb1
/// cond_br false, ^bb1, ^bb2
/// -> br ^bb2
///
struct SimplifyConstCondBranchPred : public OpRewritePattern<CondBranchOp> {
using OpRewritePattern<CondBranchOp>::OpRewritePattern;
LogicalResult matchAndRewrite(CondBranchOp condbr,
PatternRewriter &rewriter) const override {
if (matchPattern(condbr.getCondition(), m_NonZero())) {
// True branch taken.
rewriter.replaceOpWithNewOp<BranchOp>(condbr, condbr.getTrueDest(),
condbr.getTrueOperands());
return success();
}
if (matchPattern(condbr.getCondition(), m_Zero())) {
// False branch taken.
rewriter.replaceOpWithNewOp<BranchOp>(condbr, condbr.getFalseDest(),
condbr.getFalseOperands());
return success();
}
return failure();
}
};
/// cond_br %cond, ^bb1, ^bb2
/// ^bb1
/// br ^bbN(...)
/// ^bb2
/// br ^bbK(...)
///
/// -> cond_br %cond, ^bbN(...), ^bbK(...)
///
struct SimplifyPassThroughCondBranch : public OpRewritePattern<CondBranchOp> {
using OpRewritePattern<CondBranchOp>::OpRewritePattern;
LogicalResult matchAndRewrite(CondBranchOp condbr,
PatternRewriter &rewriter) const override {
Block *trueDest = condbr.getTrueDest(), *falseDest = condbr.getFalseDest();
ValueRange trueDestOperands = condbr.getTrueOperands();
ValueRange falseDestOperands = condbr.getFalseOperands();
SmallVector<Value, 4> trueDestOperandStorage, falseDestOperandStorage;
// Try to collapse one of the current successors.
LogicalResult collapsedTrue =
collapseBranch(trueDest, trueDestOperands, trueDestOperandStorage);
LogicalResult collapsedFalse =
collapseBranch(falseDest, falseDestOperands, falseDestOperandStorage);
if (failed(collapsedTrue) && failed(collapsedFalse))
return failure();
// Create a new branch with the collapsed successors.
rewriter.replaceOpWithNewOp<CondBranchOp>(condbr, condbr.getCondition(),
trueDest, trueDestOperands,
falseDest, falseDestOperands);
return success();
}
};
/// cond_br %cond, ^bb1(A, ..., N), ^bb1(A, ..., N)
/// -> br ^bb1(A, ..., N)
///
/// cond_br %cond, ^bb1(A), ^bb1(B)
/// -> %select = arith.select %cond, A, B
/// br ^bb1(%select)
///
struct SimplifyCondBranchIdenticalSuccessors
: public OpRewritePattern<CondBranchOp> {
using OpRewritePattern<CondBranchOp>::OpRewritePattern;
LogicalResult matchAndRewrite(CondBranchOp condbr,
PatternRewriter &rewriter) const override {
// Check that the true and false destinations are the same and have the same
// operands.
Block *trueDest = condbr.getTrueDest();
if (trueDest != condbr.getFalseDest())
return failure();
// If all of the operands match, no selects need to be generated.
OperandRange trueOperands = condbr.getTrueOperands();
OperandRange falseOperands = condbr.getFalseOperands();
if (trueOperands == falseOperands) {
rewriter.replaceOpWithNewOp<BranchOp>(condbr, trueDest, trueOperands);
return success();
}
// Otherwise, if the current block is the only predecessor insert selects
// for any mismatched branch operands.
if (trueDest->getUniquePredecessor() != condbr->getBlock())
return failure();
// Generate a select for any operands that differ between the two.
SmallVector<Value, 8> mergedOperands;
mergedOperands.reserve(trueOperands.size());
Value condition = condbr.getCondition();
for (auto it : llvm::zip(trueOperands, falseOperands)) {
if (std::get<0>(it) == std::get<1>(it))
mergedOperands.push_back(std::get<0>(it));
else
mergedOperands.push_back(rewriter.create<arith::SelectOp>(
condbr.getLoc(), condition, std::get<0>(it), std::get<1>(it)));
}
rewriter.replaceOpWithNewOp<BranchOp>(condbr, trueDest, mergedOperands);
return success();
}
};
/// ...
/// cond_br %cond, ^bb1(...), ^bb2(...)
/// ...
/// ^bb1: // has single predecessor
/// ...
/// cond_br %cond, ^bb3(...), ^bb4(...)
///
/// ->
///
/// ...
/// cond_br %cond, ^bb1(...), ^bb2(...)
/// ...
/// ^bb1: // has single predecessor
/// ...
/// br ^bb3(...)
///
struct SimplifyCondBranchFromCondBranchOnSameCondition
: public OpRewritePattern<CondBranchOp> {
using OpRewritePattern<CondBranchOp>::OpRewritePattern;
LogicalResult matchAndRewrite(CondBranchOp condbr,
PatternRewriter &rewriter) const override {
// Check that we have a single distinct predecessor.
Block *currentBlock = condbr->getBlock();
Block *predecessor = currentBlock->getSinglePredecessor();
if (!predecessor)
return failure();
// Check that the predecessor terminates with a conditional branch to this
// block and that it branches on the same condition.
auto predBranch = dyn_cast<CondBranchOp>(predecessor->getTerminator());
if (!predBranch || condbr.getCondition() != predBranch.getCondition())
return failure();
// Fold this branch to an unconditional branch.
if (currentBlock == predBranch.getTrueDest())
rewriter.replaceOpWithNewOp<BranchOp>(condbr, condbr.getTrueDest(),
condbr.getTrueDestOperands());
else
rewriter.replaceOpWithNewOp<BranchOp>(condbr, condbr.getFalseDest(),
condbr.getFalseDestOperands());
return success();
}
};
/// cond_br %arg0, ^trueB, ^falseB
///
/// ^trueB:
/// "test.consumer1"(%arg0) : (i1) -> ()
/// ...
///
/// ^falseB:
/// "test.consumer2"(%arg0) : (i1) -> ()
/// ...
///
/// ->
///
/// cond_br %arg0, ^trueB, ^falseB
/// ^trueB:
/// "test.consumer1"(%true) : (i1) -> ()
/// ...
///
/// ^falseB:
/// "test.consumer2"(%false) : (i1) -> ()
/// ...
struct CondBranchTruthPropagation : public OpRewritePattern<CondBranchOp> {
  using OpRewritePattern<CondBranchOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(CondBranchOp condbr,
                                PatternRewriter &rewriter) const override {
    bool changed = false;
    Type boolType = rewriter.getI1Type();
    // Lazily-materialized i1 constants, created at most once each so repeated
    // uses share a single constant op.
    Value cachedTrue;
    Value cachedFalse;

    // Replace uses of the branch condition inside `dest` with the constant
    // `boolVal`, provided `dest` is only reachable through that edge.
    //
    // TODO: This could be extended to any use dominated solely by the
    // true or false edge. For now we require the destination block to have a
    // single predecessor, which guarantees only that edge can reach the use.
    auto propagateInto = [&](Block *dest, bool boolVal, Value &cached) {
      if (!dest->getSinglePredecessor())
        return;
      for (OpOperand &use :
           llvm::make_early_inc_range(condbr.getCondition().getUses())) {
        if (use.getOwner()->getBlock() != dest)
          continue;
        changed = true;
        if (!cached)
          cached = rewriter.create<arith::ConstantOp>(
              condbr.getLoc(), boolType, rewriter.getBoolAttr(boolVal));
        rewriter.updateRootInPlace(use.getOwner(), [&] { use.set(cached); });
      }
    };

    propagateInto(condbr.getTrueDest(), /*boolVal=*/true, cachedTrue);
    propagateInto(condbr.getFalseDest(), /*boolVal=*/false, cachedFalse);
    return success(changed);
  }
};
} // namespace
/// Register all cond_br canonicalization patterns, one per simplification.
void CondBranchOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                               MLIRContext *context) {
  results.add<SimplifyConstCondBranchPred>(context);
  results.add<SimplifyPassThroughCondBranch>(context);
  results.add<SimplifyCondBranchIdenticalSuccessors>(context);
  results.add<SimplifyCondBranchFromCondBranchOnSameCondition>(context);
  results.add<CondBranchTruthPropagation>(context);
}
/// Return the mutable operand range forwarded to the given successor:
/// the true-destination operands for `trueIndex`, otherwise the
/// false-destination operands.
Optional<MutableOperandRange>
CondBranchOp::getMutableSuccessorOperands(unsigned index) {
  assert(index < getNumSuccessors() && "invalid successor index");
  if (index == trueIndex)
    return getTrueDestOperandsMutable();
  return getFalseDestOperandsMutable();
}
/// If the condition folds to a constant, return the destination that will be
/// taken; otherwise return null to signal the successor is unknown.
Block *CondBranchOp::getSuccessorForOperands(ArrayRef<Attribute> operands) {
  auto condAttr = operands.front().dyn_cast_or_null<IntegerAttr>();
  if (!condAttr)
    return nullptr;
  return condAttr.getValue().isOneValue() ? getTrueDest() : getFalseDest();
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// ConstantOp // ConstantOp
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -621,439 +242,6 @@ LogicalResult ReturnOp::verify() {
return success(); return success();
} }
//===----------------------------------------------------------------------===//
// SwitchOp
//===----------------------------------------------------------------------===//
/// Build a switch op from components already in attribute form. This simply
/// forwards to the generated builder, reordering the arguments into the
/// operand/attribute-before-successor order it expects.
void SwitchOp::build(OpBuilder &builder, OperationState &result, Value value,
                     Block *defaultDestination, ValueRange defaultOperands,
                     DenseIntElementsAttr caseValues,
                     BlockRange caseDestinations,
                     ArrayRef<ValueRange> caseOperands) {
  build(builder, result, value, defaultOperands, caseOperands, caseValues,
        defaultDestination, caseDestinations);
}
/// Build a switch op from raw APInt case values, packing them into a dense
/// vector attribute. An empty case list is encoded as a null attribute.
void SwitchOp::build(OpBuilder &builder, OperationState &result, Value value,
                     Block *defaultDestination, ValueRange defaultOperands,
                     ArrayRef<APInt> caseValues, BlockRange caseDestinations,
                     ArrayRef<ValueRange> caseOperands) {
  DenseIntElementsAttr valuesAttr;
  if (!caseValues.empty()) {
    // The element type of the attribute is the type of the flag value.
    auto shapedTy = VectorType::get(static_cast<int64_t>(caseValues.size()),
                                    value.getType());
    valuesAttr = DenseIntElementsAttr::get(shapedTy, caseValues);
  }
  build(builder, result, value, defaultDestination, defaultOperands, valuesAttr,
        caseDestinations, caseOperands);
}
/// <cases> ::= `default` `:` bb-id (`(` ssa-use-and-type-list `)`)?
/// ( `,` integer `:` bb-id (`(` ssa-use-and-type-list `)`)? )*
/// Parse the case list of a switch op, populating the default destination and
/// each case's value, destination block, and forwarded operands. Returns
/// failure on any syntax error.
static ParseResult parseSwitchOpCases(
    OpAsmParser &parser, Type &flagType, Block *&defaultDestination,
    SmallVectorImpl<OpAsmParser::OperandType> &defaultOperands,
    SmallVectorImpl<Type> &defaultOperandTypes,
    DenseIntElementsAttr &caseValues,
    SmallVectorImpl<Block *> &caseDestinations,
    SmallVectorImpl<SmallVector<OpAsmParser::OperandType>> &caseOperands,
    SmallVectorImpl<SmallVector<Type>> &caseOperandTypes) {
  // The `default` case is mandatory and always parsed first.
  if (parser.parseKeyword("default") || parser.parseColon() ||
      parser.parseSuccessor(defaultDestination))
    return failure();
  // Optional parenthesized operands forwarded to the default destination.
  if (succeeded(parser.parseOptionalLParen())) {
    if (parser.parseRegionArgumentList(defaultOperands) ||
        parser.parseColonTypeList(defaultOperandTypes) || parser.parseRParen())
      return failure();
  }
  // Parse the remaining comma-separated `<value>: <successor>` cases. Each
  // integer is widened/truncated to the bit width of the flag type.
  SmallVector<APInt> values;
  unsigned bitWidth = flagType.getIntOrFloatBitWidth();
  while (succeeded(parser.parseOptionalComma())) {
    int64_t value = 0;
    if (failed(parser.parseInteger(value)))
      return failure();
    values.push_back(APInt(bitWidth, value));
    Block *destination;
    SmallVector<OpAsmParser::OperandType> operands;
    SmallVector<Type> operandTypes;
    if (failed(parser.parseColon()) ||
        failed(parser.parseSuccessor(destination)))
      return failure();
    // Optional parenthesized operands forwarded to this case's destination.
    if (succeeded(parser.parseOptionalLParen())) {
      if (failed(parser.parseRegionArgumentList(operands)) ||
          failed(parser.parseColonTypeList(operandTypes)) ||
          failed(parser.parseRParen()))
        return failure();
    }
    caseDestinations.push_back(destination);
    caseOperands.emplace_back(operands);
    caseOperandTypes.emplace_back(operandTypes);
  }
  // Materialize the case values as a dense vector attribute. When there are
  // no non-default cases, `caseValues` is intentionally left null.
  if (!values.empty()) {
    ShapedType caseValueType =
        VectorType::get(static_cast<int64_t>(values.size()), flagType);
    caseValues = DenseIntElementsAttr::get(caseValueType, values);
  }
  return success();
}
/// Print the case list of a switch op: the default destination first, then
/// each `<value>: <successor>` case on its own line. Inverse of
/// parseSwitchOpCases.
static void printSwitchOpCases(
    OpAsmPrinter &p, SwitchOp op, Type flagType, Block *defaultDestination,
    OperandRange defaultOperands, TypeRange defaultOperandTypes,
    DenseIntElementsAttr caseValues, SuccessorRange caseDestinations,
    OperandRangeRange caseOperands, const TypeRangeRange &caseOperandTypes) {
  p << "  default: ";
  p.printSuccessorAndUseList(defaultDestination, defaultOperands);
  // A null attribute means there are no non-default cases.
  if (!caseValues)
    return;
  for (const auto &it : llvm::enumerate(caseValues.getValues<APInt>())) {
    p << ',';
    p.printNewline();
    p << "  ";
    p << it.value().getLimitedValue();
    p << ": ";
    p.printSuccessorAndUseList(caseDestinations[it.index()],
                               caseOperands[it.index()]);
  }
  p.printNewline();
}
/// Verify that the case values agree with the flag type and that the number
/// of case values matches the number of case destinations.
LogicalResult SwitchOp::verify() {
  auto caseValues = getCaseValues();
  auto caseDestinations = getCaseDestinations();

  // A switch with only a default destination carries no case values.
  if (!caseValues && caseDestinations.empty())
    return success();

  // Guard before dereferencing: the original code read `caseValues->getType()`
  // prior to any null check, which is UB when case destinations exist without
  // a corresponding case value attribute.
  if (!caseValues)
    return emitOpError() << "expected 'case_values' to be specified when "
                            "case destinations are present";

  Type flagType = getFlag().getType();
  Type caseValueType = caseValues->getType().getElementType();
  if (caseValueType != flagType)
    return emitOpError() << "'flag' type (" << flagType
                         << ") should match case value type (" << caseValueType
                         << ")";

  if (caseValues->size() != static_cast<int64_t>(caseDestinations.size()))
    return emitOpError() << "number of case values (" << caseValues->size()
                         << ") should match number of "
                            "case destinations ("
                         << caseDestinations.size() << ")";
  return success();
}
/// Return the mutable operand range forwarded to the given successor.
/// Successor 0 is the default destination; successors 1..N map to cases
/// 0..N-1.
Optional<MutableOperandRange>
SwitchOp::getMutableSuccessorOperands(unsigned index) {
  assert(index < getNumSuccessors() && "invalid successor index");
  if (index == 0)
    return getDefaultOperandsMutable();
  return getCaseOperandsMutable(index - 1);
}
/// If the flag folds to a constant, return the destination that will be
/// taken; otherwise return null to signal the successor is unknown.
Block *SwitchOp::getSuccessorForOperands(ArrayRef<Attribute> operands) {
  Optional<DenseIntElementsAttr> caseValues = getCaseValues();
  // With no cases, the default destination is unconditionally taken.
  if (!caseValues)
    return getDefaultDestination();

  auto flagAttr = operands.front().dyn_cast_or_null<IntegerAttr>();
  if (!flagAttr)
    return nullptr;

  // Look for a case matching the constant flag; fall back to the default
  // destination when none matches.
  SuccessorRange dests = getCaseDestinations();
  for (const auto &it : llvm::enumerate(caseValues->getValues<APInt>()))
    if (it.value() == flagAttr.getValue())
      return dests[it.index()];
  return getDefaultDestination();
}
/// switch %flag : i32, [
/// default: ^bb1
/// ]
/// -> br ^bb1
static LogicalResult simplifySwitchWithOnlyDefault(SwitchOp switchOp,
                                                   PatternRewriter &rewriter) {
  // With no case destinations the default edge is always taken, so the
  // switch degenerates to an unconditional branch.
  if (!switchOp.getCaseDestinations().empty())
    return failure();
  rewriter.replaceOpWithNewOp<BranchOp>(switchOp,
                                        switchOp.getDefaultDestination(),
                                        switchOp.getDefaultOperands());
  return success();
}
/// switch %flag : i32, [
/// default: ^bb1,
/// 42: ^bb1,
/// 43: ^bb2
/// ]
/// ->
/// switch %flag : i32, [
/// default: ^bb1,
/// 43: ^bb2
/// ]
/// Drop any case that targets the default destination with the same forwarded
/// operands; such a case is indistinguishable from falling through to the
/// default.
static LogicalResult
dropSwitchCasesThatMatchDefault(SwitchOp op, PatternRewriter &rewriter) {
  auto caseValues = op.getCaseValues();
  // Guard before dereferencing: a switch with no cases carries no case-value
  // attribute, and the original code unconditionally dereferenced the empty
  // Optional below.
  if (!caseValues)
    return failure();

  SmallVector<Block *> newCaseDestinations;
  SmallVector<ValueRange> newCaseOperands;
  SmallVector<APInt> newCaseValues;
  bool requiresChange = false;
  auto caseDests = op.getCaseDestinations();

  for (const auto &it : llvm::enumerate(caseValues->getValues<APInt>())) {
    // Redundant case: identical destination and operands as the default.
    if (caseDests[it.index()] == op.getDefaultDestination() &&
        op.getCaseOperands(it.index()) == op.getDefaultOperands()) {
      requiresChange = true;
      continue;
    }
    newCaseDestinations.push_back(caseDests[it.index()]);
    newCaseOperands.push_back(op.getCaseOperands(it.index()));
    newCaseValues.push_back(it.value());
  }

  if (!requiresChange)
    return failure();

  rewriter.replaceOpWithNewOp<SwitchOp>(
      op, op.getFlag(), op.getDefaultDestination(), op.getDefaultOperands(),
      newCaseValues, newCaseDestinations, newCaseOperands);
  return success();
}
/// Helper for folding a switch with a constant value.
/// switch %c_42 : i32, [
/// default: ^bb1 ,
/// 42: ^bb2,
/// 43: ^bb3
/// ]
/// -> br ^bb2
static void foldSwitch(SwitchOp op, PatternRewriter &rewriter,
                       const APInt &caseValue) {
  auto caseValues = op.getCaseValues();
  // If some case matches the known flag value, take that edge.
  for (const auto &entry : llvm::enumerate(caseValues->getValues<APInt>())) {
    if (entry.value() != caseValue)
      continue;
    rewriter.replaceOpWithNewOp<BranchOp>(
        op, op.getCaseDestinations()[entry.index()],
        op.getCaseOperands(entry.index()));
    return;
  }
  // No case matched; the default destination is taken.
  rewriter.replaceOpWithNewOp<BranchOp>(op, op.getDefaultDestination(),
                                        op.getDefaultOperands());
}
/// switch %c_42 : i32, [
/// default: ^bb1,
/// 42: ^bb2,
/// 43: ^bb3
/// ]
/// -> br ^bb2
static LogicalResult simplifyConstSwitchValue(SwitchOp op,
                                              PatternRewriter &rewriter) {
  // Only applies when the flag is produced by an integer constant.
  APInt flagValue;
  if (!matchPattern(op.getFlag(), m_ConstantInt(&flagValue)))
    return failure();
  foldSwitch(op, rewriter, flagValue);
  return success();
}
/// switch %c_42 : i32, [
/// default: ^bb1,
/// 42: ^bb2,
/// ]
/// ^bb2:
/// br ^bb3
/// ->
/// switch %c_42 : i32, [
/// default: ^bb1,
/// 42: ^bb3,
/// ]
static LogicalResult simplifyPassThroughSwitch(SwitchOp op,
                                               PatternRewriter &rewriter) {
  auto caseValues = op.getCaseValues();
  auto caseDests = op.getCaseDestinations();
  SmallVector<Block *> updatedDests;
  SmallVector<ValueRange> updatedOperands;
  // Backing storage for operand lists rewritten by collapseBranch; one entry
  // per case plus one for the default.
  SmallVector<SmallVector<Value>> operandStorage;
  bool changed = false;

  // Try to collapse each case edge through a pass-through block.
  for (int64_t idx = 0, numCases = caseValues->size(); idx < numCases; ++idx) {
    Block *dest = caseDests[idx];
    ValueRange destOperands = op.getCaseOperands(idx);
    operandStorage.emplace_back();
    changed |=
        succeeded(collapseBranch(dest, destOperands, operandStorage.back()));
    updatedDests.push_back(dest);
    updatedOperands.push_back(destOperands);
  }

  // The default edge may be collapsible as well.
  Block *newDefaultDest = op.getDefaultDestination();
  ValueRange newDefaultOperands = op.getDefaultOperands();
  operandStorage.emplace_back();
  changed |= succeeded(
      collapseBranch(newDefaultDest, newDefaultOperands,
                     operandStorage.back()));

  if (!changed)
    return failure();

  rewriter.replaceOpWithNewOp<SwitchOp>(op, op.getFlag(), newDefaultDest,
                                        newDefaultOperands,
                                        caseValues.getValue(), updatedDests,
                                        updatedOperands);
  return success();
}
/// switch %flag : i32, [
/// default: ^bb1,
/// 42: ^bb2,
/// ]
/// ^bb2:
/// switch %flag : i32, [
/// default: ^bb3,
/// 42: ^bb4
/// ]
/// ->
/// switch %flag : i32, [
/// default: ^bb1,
/// 42: ^bb2,
/// ]
/// ^bb2:
/// br ^bb4
///
/// and
///
/// switch %flag : i32, [
/// default: ^bb1,
/// 42: ^bb2,
/// ]
/// ^bb2:
/// switch %flag : i32, [
/// default: ^bb3,
/// 43: ^bb4
/// ]
/// ->
/// switch %flag : i32, [
/// default: ^bb1,
/// 42: ^bb2,
/// ]
/// ^bb2:
/// br ^bb3
static LogicalResult
simplifySwitchFromSwitchOnSameCondition(SwitchOp op,
                                        PatternRewriter &rewriter) {
  // This block must be reachable from exactly one predecessor.
  Block *block = op->getBlock();
  Block *pred = block->getSinglePredecessor();
  if (!pred)
    return failure();

  // That predecessor must end in a switch on the same flag, and must not
  // reach this block through its default destination (the default edge
  // tells us nothing definite about the flag value).
  auto parentSwitch = dyn_cast<SwitchOp>(pred->getTerminator());
  if (!parentSwitch || parentSwitch.getFlag() != op.getFlag() ||
      parentSwitch.getDefaultDestination() == block)
    return failure();

  // The flag value is now constrained: if no parent case targets this block,
  // it cannot match any value either, so take our default edge. Otherwise
  // fold to the destination of the matching case value.
  SuccessorRange parentDests = parentSwitch.getCaseDestinations();
  auto destIt = llvm::find(parentDests, block);
  if (destIt == parentDests.end()) {
    rewriter.replaceOpWithNewOp<BranchOp>(op, op.getDefaultDestination(),
                                          op.getDefaultOperands());
    return success();
  }
  Optional<DenseIntElementsAttr> parentValues = parentSwitch.getCaseValues();
  foldSwitch(op, rewriter,
             parentValues->getValues<APInt>()[destIt - parentDests.begin()]);
  return success();
}
/// switch %flag : i32, [
/// default: ^bb1,
/// 42: ^bb2
/// ]
/// ^bb1:
/// switch %flag : i32, [
/// default: ^bb3,
/// 42: ^bb4,
/// 43: ^bb5
/// ]
/// ->
/// switch %flag : i32, [
/// default: ^bb1,
/// 42: ^bb2,
/// ]
/// ^bb1:
/// switch %flag : i32, [
/// default: ^bb3,
/// 43: ^bb5
/// ]
/// Prune from this switch any case value that the dominating switch on the
/// same flag has already ruled out by taking its default edge here.
static LogicalResult
simplifySwitchFromDefaultSwitchOnSameCondition(SwitchOp op,
                                               PatternRewriter &rewriter) {
  // Check that we have a single distinct predecessor.
  Block *currentBlock = op->getBlock();
  Block *predecessor = currentBlock->getSinglePredecessor();
  if (!predecessor)
    return failure();
  // Check that the predecessor terminates with a switch branch to this block
  // and that it branches on the same condition and that this branch is the
  // default destination.
  auto predSwitch = dyn_cast<SwitchOp>(predecessor->getTerminator());
  if (!predSwitch || op.getFlag() != predSwitch.getFlag() ||
      predSwitch.getDefaultDestination() != currentBlock)
    return failure();
  // Delete case values that are not possible here: any value with its own
  // case edge in the predecessor would never reach this block through the
  // default edge.
  DenseSet<APInt> caseValuesToRemove;
  auto predDests = predSwitch.getCaseDestinations();
  auto predCaseValues = predSwitch.getCaseValues();
  for (int64_t i = 0, size = predCaseValues->size(); i < size; ++i)
    if (currentBlock != predDests[i])
      caseValuesToRemove.insert(predCaseValues->getValues<APInt>()[i]);
  // Rebuild the case list, dropping any value proven impossible above.
  SmallVector<Block *> newCaseDestinations;
  SmallVector<ValueRange> newCaseOperands;
  SmallVector<APInt> newCaseValues;
  bool requiresChange = false;
  auto caseValues = op.getCaseValues();
  auto caseDests = op.getCaseDestinations();
  for (const auto &it : llvm::enumerate(caseValues->getValues<APInt>())) {
    if (caseValuesToRemove.contains(it.value())) {
      requiresChange = true;
      continue;
    }
    newCaseDestinations.push_back(caseDests[it.index()]);
    newCaseOperands.push_back(op.getCaseOperands(it.index()));
    newCaseValues.push_back(it.value());
  }
  // Only rewrite when something was actually dropped.
  if (!requiresChange)
    return failure();
  rewriter.replaceOpWithNewOp<SwitchOp>(
      op, op.getFlag(), op.getDefaultDestination(), op.getDefaultOperands(),
      newCaseValues, newCaseDestinations, newCaseOperands);
  return success();
}
/// Register all switch canonicalization patterns, one per simplification.
void SwitchOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                           MLIRContext *context) {
  results.add(&simplifySwitchWithOnlyDefault);
  results.add(&dropSwitchCasesThatMatchDefault);
  results.add(&simplifyConstSwitchValue);
  results.add(&simplifyPassThroughSwitch);
  results.add(&simplifySwitchFromSwitchOnSameCondition);
  results.add(&simplifySwitchFromDefaultSwitchOnSameCondition);
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// TableGen'd op method definitions // TableGen'd op method definitions
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//

View File

@ -7,6 +7,7 @@
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
#include "mlir/Dialect/EmitC/IR/EmitC.h" #include "mlir/Dialect/EmitC/IR/EmitC.h"
#include "mlir/Dialect/Math/IR/Math.h" #include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/SCF/SCF.h" #include "mlir/Dialect/SCF/SCF.h"
@ -41,6 +42,7 @@ void registerToCppTranslation() {
[](DialectRegistry &registry) { [](DialectRegistry &registry) {
// clang-format off // clang-format off
registry.insert<arith::ArithmeticDialect, registry.insert<arith::ArithmeticDialect,
cf::ControlFlowDialect,
emitc::EmitCDialect, emitc::EmitCDialect,
math::MathDialect, math::MathDialect,
StandardOpsDialect, StandardOpsDialect,

View File

@ -6,8 +6,7 @@
// //
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
#include <utility> #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/EmitC/IR/EmitC.h" #include "mlir/Dialect/EmitC/IR/EmitC.h"
#include "mlir/Dialect/SCF/SCF.h" #include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/Dialect/StandardOps/IR/Ops.h"
@ -23,6 +22,7 @@
#include "llvm/ADT/TypeSwitch.h" #include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h" #include "llvm/Support/Debug.h"
#include "llvm/Support/FormatVariadic.h" #include "llvm/Support/FormatVariadic.h"
#include <utility>
#define DEBUG_TYPE "translate-to-cpp" #define DEBUG_TYPE "translate-to-cpp"
@ -237,7 +237,8 @@ static LogicalResult printOperation(CppEmitter &emitter,
return printConstantOp(emitter, operation, value); return printConstantOp(emitter, operation, value);
} }
static LogicalResult printOperation(CppEmitter &emitter, BranchOp branchOp) { static LogicalResult printOperation(CppEmitter &emitter,
cf::BranchOp branchOp) {
raw_ostream &os = emitter.ostream(); raw_ostream &os = emitter.ostream();
Block &successor = *branchOp.getSuccessor(); Block &successor = *branchOp.getSuccessor();
@ -257,7 +258,7 @@ static LogicalResult printOperation(CppEmitter &emitter, BranchOp branchOp) {
} }
static LogicalResult printOperation(CppEmitter &emitter, static LogicalResult printOperation(CppEmitter &emitter,
CondBranchOp condBranchOp) { cf::CondBranchOp condBranchOp) {
raw_indented_ostream &os = emitter.ostream(); raw_indented_ostream &os = emitter.ostream();
Block &trueSuccessor = *condBranchOp.getTrueDest(); Block &trueSuccessor = *condBranchOp.getTrueDest();
Block &falseSuccessor = *condBranchOp.getFalseDest(); Block &falseSuccessor = *condBranchOp.getFalseDest();
@ -637,11 +638,12 @@ static LogicalResult printOperation(CppEmitter &emitter, FuncOp functionOp) {
return failure(); return failure();
} }
for (Operation &op : block.getOperations()) { for (Operation &op : block.getOperations()) {
// When generating code for an scf.if or std.cond_br op no semicolon needs // When generating code for an scf.if or cf.cond_br op no semicolon needs
// to be printed after the closing brace. // to be printed after the closing brace.
// When generating code for an scf.for op, printing a trailing semicolon // When generating code for an scf.for op, printing a trailing semicolon
// is handled within the printOperation function. // is handled within the printOperation function.
bool trailingSemicolon = !isa<scf::IfOp, scf::ForOp, CondBranchOp>(op); bool trailingSemicolon =
!isa<scf::IfOp, scf::ForOp, cf::CondBranchOp>(op);
if (failed(emitter.emitOperation( if (failed(emitter.emitOperation(
op, /*trailingSemicolon=*/trailingSemicolon))) op, /*trailingSemicolon=*/trailingSemicolon)))
@ -907,8 +909,8 @@ LogicalResult CppEmitter::emitOperation(Operation &op, bool trailingSemicolon) {
.Case<scf::ForOp, scf::IfOp, scf::YieldOp>( .Case<scf::ForOp, scf::IfOp, scf::YieldOp>(
[&](auto op) { return printOperation(*this, op); }) [&](auto op) { return printOperation(*this, op); })
// Standard ops. // Standard ops.
.Case<BranchOp, mlir::CallOp, CondBranchOp, mlir::ConstantOp, FuncOp, .Case<cf::BranchOp, mlir::CallOp, cf::CondBranchOp, mlir::ConstantOp,
ModuleOp, ReturnOp>( FuncOp, ModuleOp, ReturnOp>(
[&](auto op) { return printOperation(*this, op); }) [&](auto op) { return printOperation(*this, op); })
// Arithmetic ops. // Arithmetic ops.
.Case<arith::ConstantOp>( .Case<arith::ConstantOp>(

View File

@ -52,10 +52,10 @@ func @control_flow(%arg: memref<2xf32>, %cond: i1) attributes {test.ptr = "func"
%1 = memref.alloca() {test.ptr = "alloca_2"} : memref<8x64xf32> %1 = memref.alloca() {test.ptr = "alloca_2"} : memref<8x64xf32>
%2 = memref.alloc() {test.ptr = "alloc_1"} : memref<8x64xf32> %2 = memref.alloc() {test.ptr = "alloc_1"} : memref<8x64xf32>
cond_br %cond, ^bb1(%0 : memref<8x64xf32>), ^bb2(%0 : memref<8x64xf32>) cf.cond_br %cond, ^bb1(%0 : memref<8x64xf32>), ^bb2(%0 : memref<8x64xf32>)
^bb1(%arg1: memref<8x64xf32>): ^bb1(%arg1: memref<8x64xf32>):
br ^bb2(%arg1 : memref<8x64xf32>) cf.br ^bb2(%arg1 : memref<8x64xf32>)
^bb2(%arg2: memref<8x64xf32>): ^bb2(%arg2: memref<8x64xf32>):
return return
@ -85,10 +85,10 @@ func @control_flow_merge(%arg: memref<2xf32>, %cond: i1) attributes {test.ptr =
%1 = memref.alloca() {test.ptr = "alloca_2"} : memref<8x64xf32> %1 = memref.alloca() {test.ptr = "alloca_2"} : memref<8x64xf32>
%2 = memref.alloc() {test.ptr = "alloc_1"} : memref<8x64xf32> %2 = memref.alloc() {test.ptr = "alloc_1"} : memref<8x64xf32>
cond_br %cond, ^bb1(%0 : memref<8x64xf32>), ^bb2(%2 : memref<8x64xf32>) cf.cond_br %cond, ^bb1(%0 : memref<8x64xf32>), ^bb2(%2 : memref<8x64xf32>)
^bb1(%arg1: memref<8x64xf32>): ^bb1(%arg1: memref<8x64xf32>):
br ^bb2(%arg1 : memref<8x64xf32>) cf.br ^bb2(%arg1 : memref<8x64xf32>)
^bb2(%arg2: memref<8x64xf32>): ^bb2(%arg2: memref<8x64xf32>):
return return

View File

@ -2,11 +2,11 @@
// CHECK-LABEL: Testing : func_condBranch // CHECK-LABEL: Testing : func_condBranch
func @func_condBranch(%cond : i1) { func @func_condBranch(%cond : i1) {
cond_br %cond, ^bb1, ^bb2 cf.cond_br %cond, ^bb1, ^bb2
^bb1: ^bb1:
br ^exit cf.br ^exit
^bb2: ^bb2:
br ^exit cf.br ^exit
^exit: ^exit:
return return
} }
@ -49,14 +49,14 @@ func @func_condBranch(%cond : i1) {
// CHECK-LABEL: Testing : func_loop // CHECK-LABEL: Testing : func_loop
func @func_loop(%arg0 : i32, %arg1 : i32) { func @func_loop(%arg0 : i32, %arg1 : i32) {
br ^loopHeader(%arg0 : i32) cf.br ^loopHeader(%arg0 : i32)
^loopHeader(%counter : i32): ^loopHeader(%counter : i32):
%lessThan = arith.cmpi slt, %counter, %arg1 : i32 %lessThan = arith.cmpi slt, %counter, %arg1 : i32
cond_br %lessThan, ^loopBody, ^exit cf.cond_br %lessThan, ^loopBody, ^exit
^loopBody: ^loopBody:
%const0 = arith.constant 1 : i32 %const0 = arith.constant 1 : i32
%inc = arith.addi %counter, %const0 : i32 %inc = arith.addi %counter, %const0 : i32
br ^loopHeader(%inc : i32) cf.br ^loopHeader(%inc : i32)
^exit: ^exit:
return return
} }
@ -153,17 +153,17 @@ func @func_loop_nested_region(
%arg2 : index, %arg2 : index,
%arg3 : index, %arg3 : index,
%arg4 : index) { %arg4 : index) {
br ^loopHeader(%arg0 : i32) cf.br ^loopHeader(%arg0 : i32)
^loopHeader(%counter : i32): ^loopHeader(%counter : i32):
%lessThan = arith.cmpi slt, %counter, %arg1 : i32 %lessThan = arith.cmpi slt, %counter, %arg1 : i32
cond_br %lessThan, ^loopBody, ^exit cf.cond_br %lessThan, ^loopBody, ^exit
^loopBody: ^loopBody:
%const0 = arith.constant 1 : i32 %const0 = arith.constant 1 : i32
%inc = arith.addi %counter, %const0 : i32 %inc = arith.addi %counter, %const0 : i32
scf.for %arg5 = %arg2 to %arg3 step %arg4 { scf.for %arg5 = %arg2 to %arg3 step %arg4 {
scf.for %arg6 = %arg2 to %arg3 step %arg4 { } scf.for %arg6 = %arg2 to %arg3 step %arg4 { }
} }
br ^loopHeader(%inc : i32) cf.br ^loopHeader(%inc : i32)
^exit: ^exit:
return return
} }

View File

@ -19,7 +19,7 @@ func @func_simpleBranch(%arg0: i32, %arg1 : i32) -> i32 {
// CHECK-NEXT: LiveOut: arg0@0 arg1@0 // CHECK-NEXT: LiveOut: arg0@0 arg1@0
// CHECK-NEXT: BeginLiveness // CHECK-NEXT: BeginLiveness
// CHECK-NEXT: EndLiveness // CHECK-NEXT: EndLiveness
br ^exit cf.br ^exit
^exit: ^exit:
// CHECK: Block: 1 // CHECK: Block: 1
// CHECK-NEXT: LiveIn: arg0@0 arg1@0 // CHECK-NEXT: LiveIn: arg0@0 arg1@0
@ -42,17 +42,17 @@ func @func_condBranch(%cond : i1, %arg1: i32, %arg2 : i32) -> i32 {
// CHECK-NEXT: LiveOut: arg1@0 arg2@0 // CHECK-NEXT: LiveOut: arg1@0 arg2@0
// CHECK-NEXT: BeginLiveness // CHECK-NEXT: BeginLiveness
// CHECK-NEXT: EndLiveness // CHECK-NEXT: EndLiveness
cond_br %cond, ^bb1, ^bb2 cf.cond_br %cond, ^bb1, ^bb2
^bb1: ^bb1:
// CHECK: Block: 1 // CHECK: Block: 1
// CHECK-NEXT: LiveIn: arg1@0 arg2@0 // CHECK-NEXT: LiveIn: arg1@0 arg2@0
// CHECK-NEXT: LiveOut: arg1@0 arg2@0 // CHECK-NEXT: LiveOut: arg1@0 arg2@0
br ^exit cf.br ^exit
^bb2: ^bb2:
// CHECK: Block: 2 // CHECK: Block: 2
// CHECK-NEXT: LiveIn: arg1@0 arg2@0 // CHECK-NEXT: LiveIn: arg1@0 arg2@0
// CHECK-NEXT: LiveOut: arg1@0 arg2@0 // CHECK-NEXT: LiveOut: arg1@0 arg2@0
br ^exit cf.br ^exit
^exit: ^exit:
// CHECK: Block: 3 // CHECK: Block: 3
// CHECK-NEXT: LiveIn: arg1@0 arg2@0 // CHECK-NEXT: LiveIn: arg1@0 arg2@0
@ -74,7 +74,7 @@ func @func_loop(%arg0 : i32, %arg1 : i32) -> i32 {
// CHECK-NEXT: LiveIn:{{ *$}} // CHECK-NEXT: LiveIn:{{ *$}}
// CHECK-NEXT: LiveOut: arg1@0 // CHECK-NEXT: LiveOut: arg1@0
%const0 = arith.constant 0 : i32 %const0 = arith.constant 0 : i32
br ^loopHeader(%const0, %arg0 : i32, i32) cf.br ^loopHeader(%const0, %arg0 : i32, i32)
^loopHeader(%counter : i32, %i : i32): ^loopHeader(%counter : i32, %i : i32):
// CHECK: Block: 1 // CHECK: Block: 1
// CHECK-NEXT: LiveIn: arg1@0 // CHECK-NEXT: LiveIn: arg1@0
@ -82,10 +82,10 @@ func @func_loop(%arg0 : i32, %arg1 : i32) -> i32 {
// CHECK-NEXT: BeginLiveness // CHECK-NEXT: BeginLiveness
// CHECK-NEXT: val_5 // CHECK-NEXT: val_5
// CHECK-NEXT: %2 = arith.cmpi // CHECK-NEXT: %2 = arith.cmpi
// CHECK-NEXT: cond_br // CHECK-NEXT: cf.cond_br
// CHECK-NEXT: EndLiveness // CHECK-NEXT: EndLiveness
%lessThan = arith.cmpi slt, %counter, %arg1 : i32 %lessThan = arith.cmpi slt, %counter, %arg1 : i32
cond_br %lessThan, ^loopBody(%i : i32), ^exit(%i : i32) cf.cond_br %lessThan, ^loopBody(%i : i32), ^exit(%i : i32)
^loopBody(%val : i32): ^loopBody(%val : i32):
// CHECK: Block: 2 // CHECK: Block: 2
// CHECK-NEXT: LiveIn: arg1@0 arg0@1 // CHECK-NEXT: LiveIn: arg1@0 arg0@1
@ -98,12 +98,12 @@ func @func_loop(%arg0 : i32, %arg1 : i32) -> i32 {
// CHECK-NEXT: val_8 // CHECK-NEXT: val_8
// CHECK-NEXT: %4 = arith.addi // CHECK-NEXT: %4 = arith.addi
// CHECK-NEXT: %5 = arith.addi // CHECK-NEXT: %5 = arith.addi
// CHECK-NEXT: br // CHECK-NEXT: cf.br
// CHECK: EndLiveness // CHECK: EndLiveness
%const1 = arith.constant 1 : i32 %const1 = arith.constant 1 : i32
%inc = arith.addi %val, %const1 : i32 %inc = arith.addi %val, %const1 : i32
%inc2 = arith.addi %counter, %const1 : i32 %inc2 = arith.addi %counter, %const1 : i32
br ^loopHeader(%inc, %inc2 : i32, i32) cf.br ^loopHeader(%inc, %inc2 : i32, i32)
^exit(%sum : i32): ^exit(%sum : i32):
// CHECK: Block: 3 // CHECK: Block: 3
// CHECK-NEXT: LiveIn: arg1@0 // CHECK-NEXT: LiveIn: arg1@0
@ -147,14 +147,14 @@ func @func_ranges(%cond : i1, %arg1 : i32, %arg2 : i32, %arg3 : i32) -> i32 {
// CHECK-NEXT: val_9 // CHECK-NEXT: val_9
// CHECK-NEXT: %4 = arith.muli // CHECK-NEXT: %4 = arith.muli
// CHECK-NEXT: %5 = arith.addi // CHECK-NEXT: %5 = arith.addi
// CHECK-NEXT: cond_br // CHECK-NEXT: cf.cond_br
// CHECK-NEXT: %c // CHECK-NEXT: %c
// CHECK-NEXT: %6 = arith.muli // CHECK-NEXT: %6 = arith.muli
// CHECK-NEXT: %7 = arith.muli // CHECK-NEXT: %7 = arith.muli
// CHECK-NEXT: %8 = arith.addi // CHECK-NEXT: %8 = arith.addi
// CHECK-NEXT: val_10 // CHECK-NEXT: val_10
// CHECK-NEXT: %5 = arith.addi // CHECK-NEXT: %5 = arith.addi
// CHECK-NEXT: cond_br // CHECK-NEXT: cf.cond_br
// CHECK-NEXT: %7 // CHECK-NEXT: %7
// CHECK: EndLiveness // CHECK: EndLiveness
%0 = arith.addi %arg1, %arg2 : i32 %0 = arith.addi %arg1, %arg2 : i32
@ -164,7 +164,7 @@ func @func_ranges(%cond : i1, %arg1 : i32, %arg2 : i32, %arg3 : i32) -> i32 {
%3 = arith.muli %0, %1 : i32 %3 = arith.muli %0, %1 : i32
%4 = arith.muli %3, %2 : i32 %4 = arith.muli %3, %2 : i32
%5 = arith.addi %4, %const1 : i32 %5 = arith.addi %4, %const1 : i32
cond_br %cond, ^bb1, ^bb2 cf.cond_br %cond, ^bb1, ^bb2
^bb1: ^bb1:
// CHECK: Block: 1 // CHECK: Block: 1
@ -172,7 +172,7 @@ func @func_ranges(%cond : i1, %arg1 : i32, %arg2 : i32, %arg3 : i32) -> i32 {
// CHECK-NEXT: LiveOut: arg2@0 // CHECK-NEXT: LiveOut: arg2@0
%const4 = arith.constant 4 : i32 %const4 = arith.constant 4 : i32
%6 = arith.muli %4, %const4 : i32 %6 = arith.muli %4, %const4 : i32
br ^exit(%6 : i32) cf.br ^exit(%6 : i32)
^bb2: ^bb2:
// CHECK: Block: 2 // CHECK: Block: 2
@ -180,7 +180,7 @@ func @func_ranges(%cond : i1, %arg1 : i32, %arg2 : i32, %arg3 : i32) -> i32 {
// CHECK-NEXT: LiveOut: arg2@0 // CHECK-NEXT: LiveOut: arg2@0
%7 = arith.muli %4, %5 : i32 %7 = arith.muli %4, %5 : i32
%8 = arith.addi %4, %arg2 : i32 %8 = arith.addi %4, %arg2 : i32
br ^exit(%8 : i32) cf.br ^exit(%8 : i32)
^exit(%sum : i32): ^exit(%sum : i32):
// CHECK: Block: 3 // CHECK: Block: 3
@ -284,7 +284,7 @@ func @nested_region3(
// CHECK-NEXT: %0 = arith.addi // CHECK-NEXT: %0 = arith.addi
// CHECK-NEXT: %1 = arith.addi // CHECK-NEXT: %1 = arith.addi
// CHECK-NEXT: scf.for // CHECK-NEXT: scf.for
// CHECK: // br ^bb1 // CHECK: // cf.br ^bb1
// CHECK-NEXT: %2 = arith.addi // CHECK-NEXT: %2 = arith.addi
// CHECK-NEXT: scf.for // CHECK-NEXT: scf.for
// CHECK: // %2 = arith.addi // CHECK: // %2 = arith.addi
@ -301,7 +301,7 @@ func @nested_region3(
%2 = arith.addi %0, %arg5 : i32 %2 = arith.addi %0, %arg5 : i32
memref.store %2, %buffer[] : memref<i32> memref.store %2, %buffer[] : memref<i32>
} }
br ^exit cf.br ^exit
^exit: ^exit:
// CHECK: Block: 2 // CHECK: Block: 2

View File

@ -1531,10 +1531,10 @@ int registerOnlyStd() {
fprintf(stderr, "@registration\n"); fprintf(stderr, "@registration\n");
// CHECK-LABEL: @registration // CHECK-LABEL: @registration
// CHECK: std.cond_br is_registered: 1 // CHECK: cf.cond_br is_registered: 1
fprintf(stderr, "std.cond_br is_registered: %d\n", fprintf(stderr, "cf.cond_br is_registered: %d\n",
mlirContextIsRegisteredOperation( mlirContextIsRegisteredOperation(
ctx, mlirStringRefCreateFromCString("std.cond_br"))); ctx, mlirStringRefCreateFromCString("cf.cond_br")));
// CHECK: std.not_existing_op is_registered: 0 // CHECK: std.not_existing_op is_registered: 0
fprintf(stderr, "std.not_existing_op is_registered: %d\n", fprintf(stderr, "std.not_existing_op is_registered: %d\n",

View File

@ -27,7 +27,7 @@ func @execute_no_async_args(%arg0: f32, %arg1: memref<1xf32>) {
// CHECK: %[[IS_ERROR:.*]] = call @mlirAsyncRuntimeIsTokenError(%[[TOKEN]]) // CHECK: %[[IS_ERROR:.*]] = call @mlirAsyncRuntimeIsTokenError(%[[TOKEN]])
// CHECK: %[[TRUE:.*]] = arith.constant true // CHECK: %[[TRUE:.*]] = arith.constant true
// CHECK: %[[NOT_ERROR:.*]] = arith.xori %[[IS_ERROR]], %[[TRUE]] : i1 // CHECK: %[[NOT_ERROR:.*]] = arith.xori %[[IS_ERROR]], %[[TRUE]] : i1
// CHECK: assert %[[NOT_ERROR]] // CHECK: cf.assert %[[NOT_ERROR]]
// CHECK-NEXT: return // CHECK-NEXT: return
async.await %token : !async.token async.await %token : !async.token
return return
@ -90,7 +90,7 @@ func @nested_async_execute(%arg0: f32, %arg1: f32, %arg2: memref<1xf32>) {
// CHECK: %[[IS_ERROR:.*]] = call @mlirAsyncRuntimeIsTokenError(%[[TOKEN]]) // CHECK: %[[IS_ERROR:.*]] = call @mlirAsyncRuntimeIsTokenError(%[[TOKEN]])
// CHECK: %[[TRUE:.*]] = arith.constant true // CHECK: %[[TRUE:.*]] = arith.constant true
// CHECK: %[[NOT_ERROR:.*]] = arith.xori %[[IS_ERROR]], %[[TRUE]] : i1 // CHECK: %[[NOT_ERROR:.*]] = arith.xori %[[IS_ERROR]], %[[TRUE]] : i1
// CHECK: assert %[[NOT_ERROR]] // CHECK: cf.assert %[[NOT_ERROR]]
async.await %token0 : !async.token async.await %token0 : !async.token
return return
} }

View File

@ -0,0 +1,41 @@
// RUN: mlir-opt -split-input-file -convert-std-to-spirv -verify-diagnostics %s | FileCheck %s
//===----------------------------------------------------------------------===//
// cf.br, cf.cond_br
//===----------------------------------------------------------------------===//
module attributes {
spv.target_env = #spv.target_env<#spv.vce<v1.0, [], []>, {}>
} {
// CHECK-LABEL: func @simple_loop
func @simple_loop(index, index, index) {
^bb0(%begin : index, %end : index, %step : index):
// CHECK-NEXT: spv.Branch ^bb1
cf.br ^bb1
// CHECK-NEXT: ^bb1: // pred: ^bb0
// CHECK-NEXT: spv.Branch ^bb2({{.*}} : i32)
^bb1: // pred: ^bb0
cf.br ^bb2(%begin : index)
// CHECK: ^bb2({{.*}}: i32): // 2 preds: ^bb1, ^bb3
// CHECK-NEXT: {{.*}} = spv.SLessThan {{.*}}, {{.*}} : i32
// CHECK-NEXT: spv.BranchConditional {{.*}}, ^bb3, ^bb4
^bb2(%0: index): // 2 preds: ^bb1, ^bb3
%1 = arith.cmpi slt, %0, %end : index
cf.cond_br %1, ^bb3, ^bb4
// CHECK: ^bb3: // pred: ^bb2
// CHECK-NEXT: {{.*}} = spv.IAdd {{.*}}, {{.*}} : i32
// CHECK-NEXT: spv.Branch ^bb2({{.*}} : i32)
^bb3: // pred: ^bb2
%2 = arith.addi %0, %step : index
cf.br ^bb2(%2 : index)
// CHECK: ^bb4: // pred: ^bb2
^bb4: // pred: ^bb2
return
}
}

View File

@ -168,16 +168,16 @@ gpu.module @test_module {
%c128 = arith.constant 128 : index %c128 = arith.constant 128 : index
%c32 = arith.constant 32 : index %c32 = arith.constant 32 : index
%0 = gpu.subgroup_mma_load_matrix %arg2[%c0, %c0] {leadDimension = 128 : index} : memref<128x128xf16> -> !gpu.mma_matrix<16x16xf16, "COp"> %0 = gpu.subgroup_mma_load_matrix %arg2[%c0, %c0] {leadDimension = 128 : index} : memref<128x128xf16> -> !gpu.mma_matrix<16x16xf16, "COp">
br ^bb1(%c0, %0 : index, !gpu.mma_matrix<16x16xf16, "COp">) cf.br ^bb1(%c0, %0 : index, !gpu.mma_matrix<16x16xf16, "COp">)
^bb1(%1: index, %2: !gpu.mma_matrix<16x16xf16, "COp">): // 2 preds: ^bb0, ^bb2 ^bb1(%1: index, %2: !gpu.mma_matrix<16x16xf16, "COp">): // 2 preds: ^bb0, ^bb2
%3 = arith.cmpi slt, %1, %c128 : index %3 = arith.cmpi slt, %1, %c128 : index
cond_br %3, ^bb2, ^bb3 cf.cond_br %3, ^bb2, ^bb3
^bb2: // pred: ^bb1 ^bb2: // pred: ^bb1
%4 = gpu.subgroup_mma_load_matrix %arg0[%c0, %1] {leadDimension = 128 : index} : memref<128x128xf16> -> !gpu.mma_matrix<16x16xf16, "AOp"> %4 = gpu.subgroup_mma_load_matrix %arg0[%c0, %1] {leadDimension = 128 : index} : memref<128x128xf16> -> !gpu.mma_matrix<16x16xf16, "AOp">
%5 = gpu.subgroup_mma_load_matrix %arg1[%1, %c0] {leadDimension = 128 : index} : memref<128x128xf16> -> !gpu.mma_matrix<16x16xf16, "BOp"> %5 = gpu.subgroup_mma_load_matrix %arg1[%1, %c0] {leadDimension = 128 : index} : memref<128x128xf16> -> !gpu.mma_matrix<16x16xf16, "BOp">
%6 = gpu.subgroup_mma_compute %4, %5, %2 : !gpu.mma_matrix<16x16xf16, "AOp">, !gpu.mma_matrix<16x16xf16, "BOp"> -> !gpu.mma_matrix<16x16xf16, "COp"> %6 = gpu.subgroup_mma_compute %4, %5, %2 : !gpu.mma_matrix<16x16xf16, "AOp">, !gpu.mma_matrix<16x16xf16, "BOp"> -> !gpu.mma_matrix<16x16xf16, "COp">
%7 = arith.addi %1, %c32 : index %7 = arith.addi %1, %c32 : index
br ^bb1(%7, %6 : index, !gpu.mma_matrix<16x16xf16, "COp">) cf.br ^bb1(%7, %6 : index, !gpu.mma_matrix<16x16xf16, "COp">)
^bb3: // pred: ^bb1 ^bb3: // pred: ^bb1
gpu.subgroup_mma_store_matrix %2, %arg2[%c0, %c0] {leadDimension = 128 : index} : !gpu.mma_matrix<16x16xf16, "COp">, memref<128x128xf16> gpu.subgroup_mma_store_matrix %2, %arg2[%c0, %c0] {leadDimension = 128 : index} : !gpu.mma_matrix<16x16xf16, "COp">, memref<128x128xf16>
return return

View File

@ -22,17 +22,17 @@ func @branch_loop() {
// CHECK: omp.parallel // CHECK: omp.parallel
omp.parallel { omp.parallel {
// CHECK-NEXT: llvm.br ^[[BB1:.*]](%{{[0-9]+}}, %{{[0-9]+}} : i64, i64 // CHECK-NEXT: llvm.br ^[[BB1:.*]](%{{[0-9]+}}, %{{[0-9]+}} : i64, i64
br ^bb1(%start, %end : index, index) cf.br ^bb1(%start, %end : index, index)
// CHECK-NEXT: ^[[BB1]](%[[ARG1:[0-9]+]]: i64, %[[ARG2:[0-9]+]]: i64):{{.*}} // CHECK-NEXT: ^[[BB1]](%[[ARG1:[0-9]+]]: i64, %[[ARG2:[0-9]+]]: i64):{{.*}}
^bb1(%0: index, %1: index): ^bb1(%0: index, %1: index):
// CHECK-NEXT: %[[CMP:[0-9]+]] = llvm.icmp "slt" %[[ARG1]], %[[ARG2]] : i64 // CHECK-NEXT: %[[CMP:[0-9]+]] = llvm.icmp "slt" %[[ARG1]], %[[ARG2]] : i64
%2 = arith.cmpi slt, %0, %1 : index %2 = arith.cmpi slt, %0, %1 : index
// CHECK-NEXT: llvm.cond_br %[[CMP]], ^[[BB2:.*]](%{{[0-9]+}}, %{{[0-9]+}} : i64, i64), ^[[BB3:.*]] // CHECK-NEXT: llvm.cond_br %[[CMP]], ^[[BB2:.*]](%{{[0-9]+}}, %{{[0-9]+}} : i64, i64), ^[[BB3:.*]]
cond_br %2, ^bb2(%end, %end : index, index), ^bb3 cf.cond_br %2, ^bb2(%end, %end : index, index), ^bb3
// CHECK-NEXT: ^[[BB2]](%[[ARG3:[0-9]+]]: i64, %[[ARG4:[0-9]+]]: i64): // CHECK-NEXT: ^[[BB2]](%[[ARG3:[0-9]+]]: i64, %[[ARG4:[0-9]+]]: i64):
^bb2(%3: index, %4: index): ^bb2(%3: index, %4: index):
// CHECK-NEXT: llvm.br ^[[BB1]](%[[ARG3]], %[[ARG4]] : i64, i64) // CHECK-NEXT: llvm.br ^[[BB1]](%[[ARG3]], %[[ARG4]] : i64, i64)
br ^bb1(%3, %4 : index, index) cf.br ^bb1(%3, %4 : index, index)
// CHECK-NEXT: ^[[BB3]]: // CHECK-NEXT: ^[[BB3]]:
^bb3: ^bb3:
omp.flush omp.flush

View File

@ -1,14 +1,14 @@
// RUN: mlir-opt -allow-unregistered-dialect -convert-scf-to-std %s | FileCheck %s // RUN: mlir-opt -allow-unregistered-dialect -convert-scf-to-cf %s | FileCheck %s
// CHECK-LABEL: func @simple_std_for_loop(%{{.*}}: index, %{{.*}}: index, %{{.*}}: index) { // CHECK-LABEL: func @simple_std_for_loop(%{{.*}}: index, %{{.*}}: index, %{{.*}}: index) {
// CHECK-NEXT: br ^bb1(%{{.*}} : index) // CHECK-NEXT: cf.br ^bb1(%{{.*}} : index)
// CHECK-NEXT: ^bb1(%{{.*}}: index): // 2 preds: ^bb0, ^bb2 // CHECK-NEXT: ^bb1(%{{.*}}: index): // 2 preds: ^bb0, ^bb2
// CHECK-NEXT: %{{.*}} = arith.cmpi slt, %{{.*}}, %{{.*}} : index // CHECK-NEXT: %{{.*}} = arith.cmpi slt, %{{.*}}, %{{.*}} : index
// CHECK-NEXT: cond_br %{{.*}}, ^bb2, ^bb3 // CHECK-NEXT: cf.cond_br %{{.*}}, ^bb2, ^bb3
// CHECK-NEXT: ^bb2: // pred: ^bb1 // CHECK-NEXT: ^bb2: // pred: ^bb1
// CHECK-NEXT: %{{.*}} = arith.constant 1 : index // CHECK-NEXT: %{{.*}} = arith.constant 1 : index
// CHECK-NEXT: %[[iv:.*]] = arith.addi %{{.*}}, %{{.*}} : index // CHECK-NEXT: %[[iv:.*]] = arith.addi %{{.*}}, %{{.*}} : index
// CHECK-NEXT: br ^bb1(%[[iv]] : index) // CHECK-NEXT: cf.br ^bb1(%[[iv]] : index)
// CHECK-NEXT: ^bb3: // pred: ^bb1 // CHECK-NEXT: ^bb3: // pred: ^bb1
// CHECK-NEXT: return // CHECK-NEXT: return
func @simple_std_for_loop(%arg0 : index, %arg1 : index, %arg2 : index) { func @simple_std_for_loop(%arg0 : index, %arg1 : index, %arg2 : index) {
@ -19,23 +19,23 @@ func @simple_std_for_loop(%arg0 : index, %arg1 : index, %arg2 : index) {
} }
// CHECK-LABEL: func @simple_std_2_for_loops(%{{.*}}: index, %{{.*}}: index, %{{.*}}: index) { // CHECK-LABEL: func @simple_std_2_for_loops(%{{.*}}: index, %{{.*}}: index, %{{.*}}: index) {
// CHECK-NEXT: br ^bb1(%{{.*}} : index) // CHECK-NEXT: cf.br ^bb1(%{{.*}} : index)
// CHECK-NEXT: ^bb1(%[[ub0:.*]]: index): // 2 preds: ^bb0, ^bb5 // CHECK-NEXT: ^bb1(%[[ub0:.*]]: index): // 2 preds: ^bb0, ^bb5
// CHECK-NEXT: %[[cond0:.*]] = arith.cmpi slt, %[[ub0]], %{{.*}} : index // CHECK-NEXT: %[[cond0:.*]] = arith.cmpi slt, %[[ub0]], %{{.*}} : index
// CHECK-NEXT: cond_br %[[cond0]], ^bb2, ^bb6 // CHECK-NEXT: cf.cond_br %[[cond0]], ^bb2, ^bb6
// CHECK-NEXT: ^bb2: // pred: ^bb1 // CHECK-NEXT: ^bb2: // pred: ^bb1
// CHECK-NEXT: %{{.*}} = arith.constant 1 : index // CHECK-NEXT: %{{.*}} = arith.constant 1 : index
// CHECK-NEXT: br ^bb3(%{{.*}} : index) // CHECK-NEXT: cf.br ^bb3(%{{.*}} : index)
// CHECK-NEXT: ^bb3(%[[ub1:.*]]: index): // 2 preds: ^bb2, ^bb4 // CHECK-NEXT: ^bb3(%[[ub1:.*]]: index): // 2 preds: ^bb2, ^bb4
// CHECK-NEXT: %[[cond1:.*]] = arith.cmpi slt, %{{.*}}, %{{.*}} : index // CHECK-NEXT: %[[cond1:.*]] = arith.cmpi slt, %{{.*}}, %{{.*}} : index
// CHECK-NEXT: cond_br %[[cond1]], ^bb4, ^bb5 // CHECK-NEXT: cf.cond_br %[[cond1]], ^bb4, ^bb5
// CHECK-NEXT: ^bb4: // pred: ^bb3 // CHECK-NEXT: ^bb4: // pred: ^bb3
// CHECK-NEXT: %{{.*}} = arith.constant 1 : index // CHECK-NEXT: %{{.*}} = arith.constant 1 : index
// CHECK-NEXT: %[[iv1:.*]] = arith.addi %{{.*}}, %{{.*}} : index // CHECK-NEXT: %[[iv1:.*]] = arith.addi %{{.*}}, %{{.*}} : index
// CHECK-NEXT: br ^bb3(%[[iv1]] : index) // CHECK-NEXT: cf.br ^bb3(%[[iv1]] : index)
// CHECK-NEXT: ^bb5: // pred: ^bb3 // CHECK-NEXT: ^bb5: // pred: ^bb3
// CHECK-NEXT: %[[iv0:.*]] = arith.addi %{{.*}}, %{{.*}} : index // CHECK-NEXT: %[[iv0:.*]] = arith.addi %{{.*}}, %{{.*}} : index
// CHECK-NEXT: br ^bb1(%[[iv0]] : index) // CHECK-NEXT: cf.br ^bb1(%[[iv0]] : index)
// CHECK-NEXT: ^bb6: // pred: ^bb1 // CHECK-NEXT: ^bb6: // pred: ^bb1
// CHECK-NEXT: return // CHECK-NEXT: return
func @simple_std_2_for_loops(%arg0 : index, %arg1 : index, %arg2 : index) { func @simple_std_2_for_loops(%arg0 : index, %arg1 : index, %arg2 : index) {
@ -49,10 +49,10 @@ func @simple_std_2_for_loops(%arg0 : index, %arg1 : index, %arg2 : index) {
} }
// CHECK-LABEL: func @simple_std_if(%{{.*}}: i1) { // CHECK-LABEL: func @simple_std_if(%{{.*}}: i1) {
// CHECK-NEXT: cond_br %{{.*}}, ^bb1, ^bb2 // CHECK-NEXT: cf.cond_br %{{.*}}, ^bb1, ^bb2
// CHECK-NEXT: ^bb1: // pred: ^bb0 // CHECK-NEXT: ^bb1: // pred: ^bb0
// CHECK-NEXT: %{{.*}} = arith.constant 1 : index // CHECK-NEXT: %{{.*}} = arith.constant 1 : index
// CHECK-NEXT: br ^bb2 // CHECK-NEXT: cf.br ^bb2
// CHECK-NEXT: ^bb2: // 2 preds: ^bb0, ^bb1 // CHECK-NEXT: ^bb2: // 2 preds: ^bb0, ^bb1
// CHECK-NEXT: return // CHECK-NEXT: return
func @simple_std_if(%arg0: i1) { func @simple_std_if(%arg0: i1) {
@ -63,13 +63,13 @@ func @simple_std_if(%arg0: i1) {
} }
// CHECK-LABEL: func @simple_std_if_else(%{{.*}}: i1) { // CHECK-LABEL: func @simple_std_if_else(%{{.*}}: i1) {
// CHECK-NEXT: cond_br %{{.*}}, ^bb1, ^bb2 // CHECK-NEXT: cf.cond_br %{{.*}}, ^bb1, ^bb2
// CHECK-NEXT: ^bb1: // pred: ^bb0 // CHECK-NEXT: ^bb1: // pred: ^bb0
// CHECK-NEXT: %{{.*}} = arith.constant 1 : index // CHECK-NEXT: %{{.*}} = arith.constant 1 : index
// CHECK-NEXT: br ^bb3 // CHECK-NEXT: cf.br ^bb3
// CHECK-NEXT: ^bb2: // pred: ^bb0 // CHECK-NEXT: ^bb2: // pred: ^bb0
// CHECK-NEXT: %{{.*}} = arith.constant 1 : index // CHECK-NEXT: %{{.*}} = arith.constant 1 : index
// CHECK-NEXT: br ^bb3 // CHECK-NEXT: cf.br ^bb3
// CHECK-NEXT: ^bb3: // 2 preds: ^bb1, ^bb2 // CHECK-NEXT: ^bb3: // 2 preds: ^bb1, ^bb2
// CHECK-NEXT: return // CHECK-NEXT: return
func @simple_std_if_else(%arg0: i1) { func @simple_std_if_else(%arg0: i1) {
@ -82,18 +82,18 @@ func @simple_std_if_else(%arg0: i1) {
} }
// CHECK-LABEL: func @simple_std_2_ifs(%{{.*}}: i1) { // CHECK-LABEL: func @simple_std_2_ifs(%{{.*}}: i1) {
// CHECK-NEXT: cond_br %{{.*}}, ^bb1, ^bb5 // CHECK-NEXT: cf.cond_br %{{.*}}, ^bb1, ^bb5
// CHECK-NEXT: ^bb1: // pred: ^bb0 // CHECK-NEXT: ^bb1: // pred: ^bb0
// CHECK-NEXT: %{{.*}} = arith.constant 1 : index // CHECK-NEXT: %{{.*}} = arith.constant 1 : index
// CHECK-NEXT: cond_br %{{.*}}, ^bb2, ^bb3 // CHECK-NEXT: cf.cond_br %{{.*}}, ^bb2, ^bb3
// CHECK-NEXT: ^bb2: // pred: ^bb1 // CHECK-NEXT: ^bb2: // pred: ^bb1
// CHECK-NEXT: %{{.*}} = arith.constant 1 : index // CHECK-NEXT: %{{.*}} = arith.constant 1 : index
// CHECK-NEXT: br ^bb4 // CHECK-NEXT: cf.br ^bb4
// CHECK-NEXT: ^bb3: // pred: ^bb1 // CHECK-NEXT: ^bb3: // pred: ^bb1
// CHECK-NEXT: %{{.*}} = arith.constant 1 : index // CHECK-NEXT: %{{.*}} = arith.constant 1 : index
// CHECK-NEXT: br ^bb4 // CHECK-NEXT: cf.br ^bb4
// CHECK-NEXT: ^bb4: // 2 preds: ^bb2, ^bb3 // CHECK-NEXT: ^bb4: // 2 preds: ^bb2, ^bb3
// CHECK-NEXT: br ^bb5 // CHECK-NEXT: cf.br ^bb5
// CHECK-NEXT: ^bb5: // 2 preds: ^bb0, ^bb4 // CHECK-NEXT: ^bb5: // 2 preds: ^bb0, ^bb4
// CHECK-NEXT: return // CHECK-NEXT: return
func @simple_std_2_ifs(%arg0: i1) { func @simple_std_2_ifs(%arg0: i1) {
@ -109,27 +109,27 @@ func @simple_std_2_ifs(%arg0: i1) {
} }
// CHECK-LABEL: func @simple_std_for_loop_with_2_ifs(%{{.*}}: index, %{{.*}}: index, %{{.*}}: index, %{{.*}}: i1) { // CHECK-LABEL: func @simple_std_for_loop_with_2_ifs(%{{.*}}: index, %{{.*}}: index, %{{.*}}: index, %{{.*}}: i1) {
// CHECK-NEXT: br ^bb1(%{{.*}} : index) // CHECK-NEXT: cf.br ^bb1(%{{.*}} : index)
// CHECK-NEXT: ^bb1(%{{.*}}: index): // 2 preds: ^bb0, ^bb7 // CHECK-NEXT: ^bb1(%{{.*}}: index): // 2 preds: ^bb0, ^bb7
// CHECK-NEXT: %{{.*}} = arith.cmpi slt, %{{.*}}, %{{.*}} : index // CHECK-NEXT: %{{.*}} = arith.cmpi slt, %{{.*}}, %{{.*}} : index
// CHECK-NEXT: cond_br %{{.*}}, ^bb2, ^bb8 // CHECK-NEXT: cf.cond_br %{{.*}}, ^bb2, ^bb8
// CHECK-NEXT: ^bb2: // pred: ^bb1 // CHECK-NEXT: ^bb2: // pred: ^bb1
// CHECK-NEXT: %{{.*}} = arith.constant 1 : index // CHECK-NEXT: %{{.*}} = arith.constant 1 : index
// CHECK-NEXT: cond_br %{{.*}}, ^bb3, ^bb7 // CHECK-NEXT: cf.cond_br %{{.*}}, ^bb3, ^bb7
// CHECK-NEXT: ^bb3: // pred: ^bb2 // CHECK-NEXT: ^bb3: // pred: ^bb2
// CHECK-NEXT: %{{.*}} = arith.constant 1 : index // CHECK-NEXT: %{{.*}} = arith.constant 1 : index
// CHECK-NEXT: cond_br %{{.*}}, ^bb4, ^bb5 // CHECK-NEXT: cf.cond_br %{{.*}}, ^bb4, ^bb5
// CHECK-NEXT: ^bb4: // pred: ^bb3 // CHECK-NEXT: ^bb4: // pred: ^bb3
// CHECK-NEXT: %{{.*}} = arith.constant 1 : index // CHECK-NEXT: %{{.*}} = arith.constant 1 : index
// CHECK-NEXT: br ^bb6 // CHECK-NEXT: cf.br ^bb6
// CHECK-NEXT: ^bb5: // pred: ^bb3 // CHECK-NEXT: ^bb5: // pred: ^bb3
// CHECK-NEXT: %{{.*}} = arith.constant 1 : index // CHECK-NEXT: %{{.*}} = arith.constant 1 : index
// CHECK-NEXT: br ^bb6 // CHECK-NEXT: cf.br ^bb6
// CHECK-NEXT: ^bb6: // 2 preds: ^bb4, ^bb5 // CHECK-NEXT: ^bb6: // 2 preds: ^bb4, ^bb5
// CHECK-NEXT: br ^bb7 // CHECK-NEXT: cf.br ^bb7
// CHECK-NEXT: ^bb7: // 2 preds: ^bb2, ^bb6 // CHECK-NEXT: ^bb7: // 2 preds: ^bb2, ^bb6
// CHECK-NEXT: %[[iv0:.*]] = arith.addi %{{.*}}, %{{.*}} : index // CHECK-NEXT: %[[iv0:.*]] = arith.addi %{{.*}}, %{{.*}} : index
// CHECK-NEXT: br ^bb1(%[[iv0]] : index) // CHECK-NEXT: cf.br ^bb1(%[[iv0]] : index)
// CHECK-NEXT: ^bb8: // pred: ^bb1 // CHECK-NEXT: ^bb8: // pred: ^bb1
// CHECK-NEXT: return // CHECK-NEXT: return
// CHECK-NEXT: } // CHECK-NEXT: }
@ -150,12 +150,12 @@ func @simple_std_for_loop_with_2_ifs(%arg0 : index, %arg1 : index, %arg2 : index
// CHECK-LABEL: func @simple_if_yield // CHECK-LABEL: func @simple_if_yield
func @simple_if_yield(%arg0: i1) -> (i1, i1) { func @simple_if_yield(%arg0: i1) -> (i1, i1) {
// CHECK: cond_br %{{.*}}, ^[[then:.*]], ^[[else:.*]] // CHECK: cf.cond_br %{{.*}}, ^[[then:.*]], ^[[else:.*]]
%0:2 = scf.if %arg0 -> (i1, i1) { %0:2 = scf.if %arg0 -> (i1, i1) {
// CHECK: ^[[then]]: // CHECK: ^[[then]]:
// CHECK: %[[v0:.*]] = arith.constant false // CHECK: %[[v0:.*]] = arith.constant false
// CHECK: %[[v1:.*]] = arith.constant true // CHECK: %[[v1:.*]] = arith.constant true
// CHECK: br ^[[dom:.*]](%[[v0]], %[[v1]] : i1, i1) // CHECK: cf.br ^[[dom:.*]](%[[v0]], %[[v1]] : i1, i1)
%c0 = arith.constant false %c0 = arith.constant false
%c1 = arith.constant true %c1 = arith.constant true
scf.yield %c0, %c1 : i1, i1 scf.yield %c0, %c1 : i1, i1
@ -163,13 +163,13 @@ func @simple_if_yield(%arg0: i1) -> (i1, i1) {
// CHECK: ^[[else]]: // CHECK: ^[[else]]:
// CHECK: %[[v2:.*]] = arith.constant false // CHECK: %[[v2:.*]] = arith.constant false
// CHECK: %[[v3:.*]] = arith.constant true // CHECK: %[[v3:.*]] = arith.constant true
// CHECK: br ^[[dom]](%[[v3]], %[[v2]] : i1, i1) // CHECK: cf.br ^[[dom]](%[[v3]], %[[v2]] : i1, i1)
%c0 = arith.constant false %c0 = arith.constant false
%c1 = arith.constant true %c1 = arith.constant true
scf.yield %c1, %c0 : i1, i1 scf.yield %c1, %c0 : i1, i1
} }
// CHECK: ^[[dom]](%[[arg1:.*]]: i1, %[[arg2:.*]]: i1): // CHECK: ^[[dom]](%[[arg1:.*]]: i1, %[[arg2:.*]]: i1):
// CHECK: br ^[[cont:.*]] // CHECK: cf.br ^[[cont:.*]]
// CHECK: ^[[cont]]: // CHECK: ^[[cont]]:
// CHECK: return %[[arg1]], %[[arg2]] // CHECK: return %[[arg1]], %[[arg2]]
return %0#0, %0#1 : i1, i1 return %0#0, %0#1 : i1, i1
@ -177,49 +177,49 @@ func @simple_if_yield(%arg0: i1) -> (i1, i1) {
// CHECK-LABEL: func @nested_if_yield // CHECK-LABEL: func @nested_if_yield
func @nested_if_yield(%arg0: i1) -> (index) { func @nested_if_yield(%arg0: i1) -> (index) {
// CHECK: cond_br %{{.*}}, ^[[first_then:.*]], ^[[first_else:.*]] // CHECK: cf.cond_br %{{.*}}, ^[[first_then:.*]], ^[[first_else:.*]]
%0 = scf.if %arg0 -> i1 { %0 = scf.if %arg0 -> i1 {
// CHECK: ^[[first_then]]: // CHECK: ^[[first_then]]:
%1 = arith.constant true %1 = arith.constant true
// CHECK: br ^[[first_dom:.*]]({{.*}}) // CHECK: cf.br ^[[first_dom:.*]]({{.*}})
scf.yield %1 : i1 scf.yield %1 : i1
} else { } else {
// CHECK: ^[[first_else]]: // CHECK: ^[[first_else]]:
%2 = arith.constant false %2 = arith.constant false
// CHECK: br ^[[first_dom]]({{.*}}) // CHECK: cf.br ^[[first_dom]]({{.*}})
scf.yield %2 : i1 scf.yield %2 : i1
} }
// CHECK: ^[[first_dom]](%[[arg1:.*]]: i1): // CHECK: ^[[first_dom]](%[[arg1:.*]]: i1):
// CHECK: br ^[[first_cont:.*]] // CHECK: cf.br ^[[first_cont:.*]]
// CHECK: ^[[first_cont]]: // CHECK: ^[[first_cont]]:
// CHECK: cond_br %[[arg1]], ^[[second_outer_then:.*]], ^[[second_outer_else:.*]] // CHECK: cf.cond_br %[[arg1]], ^[[second_outer_then:.*]], ^[[second_outer_else:.*]]
%1 = scf.if %0 -> index { %1 = scf.if %0 -> index {
// CHECK: ^[[second_outer_then]]: // CHECK: ^[[second_outer_then]]:
// CHECK: cond_br %arg0, ^[[second_inner_then:.*]], ^[[second_inner_else:.*]] // CHECK: cf.cond_br %arg0, ^[[second_inner_then:.*]], ^[[second_inner_else:.*]]
%3 = scf.if %arg0 -> index { %3 = scf.if %arg0 -> index {
// CHECK: ^[[second_inner_then]]: // CHECK: ^[[second_inner_then]]:
%4 = arith.constant 40 : index %4 = arith.constant 40 : index
// CHECK: br ^[[second_inner_dom:.*]]({{.*}}) // CHECK: cf.br ^[[second_inner_dom:.*]]({{.*}})
scf.yield %4 : index scf.yield %4 : index
} else { } else {
// CHECK: ^[[second_inner_else]]: // CHECK: ^[[second_inner_else]]:
%5 = arith.constant 41 : index %5 = arith.constant 41 : index
// CHECK: br ^[[second_inner_dom]]({{.*}}) // CHECK: cf.br ^[[second_inner_dom]]({{.*}})
scf.yield %5 : index scf.yield %5 : index
} }
// CHECK: ^[[second_inner_dom]](%[[arg2:.*]]: index): // CHECK: ^[[second_inner_dom]](%[[arg2:.*]]: index):
// CHECK: br ^[[second_inner_cont:.*]] // CHECK: cf.br ^[[second_inner_cont:.*]]
// CHECK: ^[[second_inner_cont]]: // CHECK: ^[[second_inner_cont]]:
// CHECK: br ^[[second_outer_dom:.*]]({{.*}}) // CHECK: cf.br ^[[second_outer_dom:.*]]({{.*}})
scf.yield %3 : index scf.yield %3 : index
} else { } else {
// CHECK: ^[[second_outer_else]]: // CHECK: ^[[second_outer_else]]:
%6 = arith.constant 42 : index %6 = arith.constant 42 : index
// CHECK: br ^[[second_outer_dom]]({{.*}} // CHECK: cf.br ^[[second_outer_dom]]({{.*}}
scf.yield %6 : index scf.yield %6 : index
} }
// CHECK: ^[[second_outer_dom]](%[[arg3:.*]]: index): // CHECK: ^[[second_outer_dom]](%[[arg3:.*]]: index):
// CHECK: br ^[[second_outer_cont:.*]] // CHECK: cf.br ^[[second_outer_cont:.*]]
// CHECK: ^[[second_outer_cont]]: // CHECK: ^[[second_outer_cont]]:
// CHECK: return %[[arg3]] : index // CHECK: return %[[arg3]] : index
return %1 : index return %1 : index
@ -228,22 +228,22 @@ func @nested_if_yield(%arg0: i1) -> (index) {
// CHECK-LABEL: func @parallel_loop( // CHECK-LABEL: func @parallel_loop(
// CHECK-SAME: [[VAL_0:%.*]]: index, [[VAL_1:%.*]]: index, [[VAL_2:%.*]]: index, [[VAL_3:%.*]]: index, [[VAL_4:%.*]]: index) { // CHECK-SAME: [[VAL_0:%.*]]: index, [[VAL_1:%.*]]: index, [[VAL_2:%.*]]: index, [[VAL_3:%.*]]: index, [[VAL_4:%.*]]: index) {
// CHECK: [[VAL_5:%.*]] = arith.constant 1 : index // CHECK: [[VAL_5:%.*]] = arith.constant 1 : index
// CHECK: br ^bb1([[VAL_0]] : index) // CHECK: cf.br ^bb1([[VAL_0]] : index)
// CHECK: ^bb1([[VAL_6:%.*]]: index): // CHECK: ^bb1([[VAL_6:%.*]]: index):
// CHECK: [[VAL_7:%.*]] = arith.cmpi slt, [[VAL_6]], [[VAL_2]] : index // CHECK: [[VAL_7:%.*]] = arith.cmpi slt, [[VAL_6]], [[VAL_2]] : index
// CHECK: cond_br [[VAL_7]], ^bb2, ^bb6 // CHECK: cf.cond_br [[VAL_7]], ^bb2, ^bb6
// CHECK: ^bb2: // CHECK: ^bb2:
// CHECK: br ^bb3([[VAL_1]] : index) // CHECK: cf.br ^bb3([[VAL_1]] : index)
// CHECK: ^bb3([[VAL_8:%.*]]: index): // CHECK: ^bb3([[VAL_8:%.*]]: index):
// CHECK: [[VAL_9:%.*]] = arith.cmpi slt, [[VAL_8]], [[VAL_3]] : index // CHECK: [[VAL_9:%.*]] = arith.cmpi slt, [[VAL_8]], [[VAL_3]] : index
// CHECK: cond_br [[VAL_9]], ^bb4, ^bb5 // CHECK: cf.cond_br [[VAL_9]], ^bb4, ^bb5
// CHECK: ^bb4: // CHECK: ^bb4:
// CHECK: [[VAL_10:%.*]] = arith.constant 1 : index // CHECK: [[VAL_10:%.*]] = arith.constant 1 : index
// CHECK: [[VAL_11:%.*]] = arith.addi [[VAL_8]], [[VAL_5]] : index // CHECK: [[VAL_11:%.*]] = arith.addi [[VAL_8]], [[VAL_5]] : index
// CHECK: br ^bb3([[VAL_11]] : index) // CHECK: cf.br ^bb3([[VAL_11]] : index)
// CHECK: ^bb5: // CHECK: ^bb5:
// CHECK: [[VAL_12:%.*]] = arith.addi [[VAL_6]], [[VAL_4]] : index // CHECK: [[VAL_12:%.*]] = arith.addi [[VAL_6]], [[VAL_4]] : index
// CHECK: br ^bb1([[VAL_12]] : index) // CHECK: cf.br ^bb1([[VAL_12]] : index)
// CHECK: ^bb6: // CHECK: ^bb6:
// CHECK: return // CHECK: return
// CHECK: } // CHECK: }
@ -262,16 +262,16 @@ func @parallel_loop(%arg0 : index, %arg1 : index, %arg2 : index,
// CHECK-SAME: (%[[LB:.*]]: index, %[[UB:.*]]: index, %[[STEP:.*]]: index) // CHECK-SAME: (%[[LB:.*]]: index, %[[UB:.*]]: index, %[[STEP:.*]]: index)
// CHECK: %[[INIT0:.*]] = arith.constant 0 // CHECK: %[[INIT0:.*]] = arith.constant 0
// CHECK: %[[INIT1:.*]] = arith.constant 1 // CHECK: %[[INIT1:.*]] = arith.constant 1
// CHECK: br ^[[COND:.*]](%[[LB]], %[[INIT0]], %[[INIT1]] : index, f32, f32) // CHECK: cf.br ^[[COND:.*]](%[[LB]], %[[INIT0]], %[[INIT1]] : index, f32, f32)
// //
// CHECK: ^[[COND]](%[[ITER:.*]]: index, %[[ITER_ARG0:.*]]: f32, %[[ITER_ARG1:.*]]: f32): // CHECK: ^[[COND]](%[[ITER:.*]]: index, %[[ITER_ARG0:.*]]: f32, %[[ITER_ARG1:.*]]: f32):
// CHECK: %[[CMP:.*]] = arith.cmpi slt, %[[ITER]], %[[UB]] : index // CHECK: %[[CMP:.*]] = arith.cmpi slt, %[[ITER]], %[[UB]] : index
// CHECK: cond_br %[[CMP]], ^[[BODY:.*]], ^[[CONTINUE:.*]] // CHECK: cf.cond_br %[[CMP]], ^[[BODY:.*]], ^[[CONTINUE:.*]]
// //
// CHECK: ^[[BODY]]: // CHECK: ^[[BODY]]:
// CHECK: %[[SUM:.*]] = arith.addf %[[ITER_ARG0]], %[[ITER_ARG1]] : f32 // CHECK: %[[SUM:.*]] = arith.addf %[[ITER_ARG0]], %[[ITER_ARG1]] : f32
// CHECK: %[[STEPPED:.*]] = arith.addi %[[ITER]], %[[STEP]] : index // CHECK: %[[STEPPED:.*]] = arith.addi %[[ITER]], %[[STEP]] : index
// CHECK: br ^[[COND]](%[[STEPPED]], %[[SUM]], %[[SUM]] : index, f32, f32) // CHECK: cf.br ^[[COND]](%[[STEPPED]], %[[SUM]], %[[SUM]] : index, f32, f32)
// //
// CHECK: ^[[CONTINUE]]: // CHECK: ^[[CONTINUE]]:
// CHECK: return %[[ITER_ARG0]], %[[ITER_ARG1]] : f32, f32 // CHECK: return %[[ITER_ARG0]], %[[ITER_ARG1]] : f32, f32
@ -288,18 +288,18 @@ func @for_yield(%arg0 : index, %arg1 : index, %arg2 : index) -> (f32, f32) {
// CHECK-LABEL: @nested_for_yield // CHECK-LABEL: @nested_for_yield
// CHECK-SAME: (%[[LB:.*]]: index, %[[UB:.*]]: index, %[[STEP:.*]]: index) // CHECK-SAME: (%[[LB:.*]]: index, %[[UB:.*]]: index, %[[STEP:.*]]: index)
// CHECK: %[[INIT:.*]] = arith.constant // CHECK: %[[INIT:.*]] = arith.constant
// CHECK: br ^[[COND_OUT:.*]](%[[LB]], %[[INIT]] : index, f32) // CHECK: cf.br ^[[COND_OUT:.*]](%[[LB]], %[[INIT]] : index, f32)
// CHECK: ^[[COND_OUT]](%[[ITER_OUT:.*]]: index, %[[ARG_OUT:.*]]: f32): // CHECK: ^[[COND_OUT]](%[[ITER_OUT:.*]]: index, %[[ARG_OUT:.*]]: f32):
// CHECK: cond_br %{{.*}}, ^[[BODY_OUT:.*]], ^[[CONT_OUT:.*]] // CHECK: cf.cond_br %{{.*}}, ^[[BODY_OUT:.*]], ^[[CONT_OUT:.*]]
// CHECK: ^[[BODY_OUT]]: // CHECK: ^[[BODY_OUT]]:
// CHECK: br ^[[COND_IN:.*]](%[[LB]], %[[ARG_OUT]] : index, f32) // CHECK: cf.br ^[[COND_IN:.*]](%[[LB]], %[[ARG_OUT]] : index, f32)
// CHECK: ^[[COND_IN]](%[[ITER_IN:.*]]: index, %[[ARG_IN:.*]]: f32): // CHECK: ^[[COND_IN]](%[[ITER_IN:.*]]: index, %[[ARG_IN:.*]]: f32):
// CHECK: cond_br %{{.*}}, ^[[BODY_IN:.*]], ^[[CONT_IN:.*]] // CHECK: cf.cond_br %{{.*}}, ^[[BODY_IN:.*]], ^[[CONT_IN:.*]]
// CHECK: ^[[BODY_IN]] // CHECK: ^[[BODY_IN]]
// CHECK: %[[RES:.*]] = arith.addf // CHECK: %[[RES:.*]] = arith.addf
// CHECK: br ^[[COND_IN]](%{{.*}}, %[[RES]] : index, f32) // CHECK: cf.br ^[[COND_IN]](%{{.*}}, %[[RES]] : index, f32)
// CHECK: ^[[CONT_IN]]: // CHECK: ^[[CONT_IN]]:
// CHECK: br ^[[COND_OUT]](%{{.*}}, %[[ARG_IN]] : index, f32) // CHECK: cf.br ^[[COND_OUT]](%{{.*}}, %[[ARG_IN]] : index, f32)
// CHECK: ^[[CONT_OUT]]: // CHECK: ^[[CONT_OUT]]:
// CHECK: return %[[ARG_OUT]] : f32 // CHECK: return %[[ARG_OUT]] : f32
func @nested_for_yield(%arg0 : index, %arg1 : index, %arg2 : index) -> f32 { func @nested_for_yield(%arg0 : index, %arg1 : index, %arg2 : index) -> f32 {
@ -325,13 +325,13 @@ func @simple_parallel_reduce_loop(%arg0: index, %arg1: index,
// passed across as a block argument. // passed across as a block argument.
// Branch to the condition block passing in the initial reduction value. // Branch to the condition block passing in the initial reduction value.
// CHECK: br ^[[COND:.*]](%[[LB]], %[[INIT]] // CHECK: cf.br ^[[COND:.*]](%[[LB]], %[[INIT]]
// Condition branch takes as arguments the current value of the iteration // Condition branch takes as arguments the current value of the iteration
// variable and the current partially reduced value. // variable and the current partially reduced value.
// CHECK: ^[[COND]](%[[ITER:.*]]: index, %[[ITER_ARG:.*]]: f32 // CHECK: ^[[COND]](%[[ITER:.*]]: index, %[[ITER_ARG:.*]]: f32
// CHECK: %[[COMP:.*]] = arith.cmpi slt, %[[ITER]], %[[UB]] // CHECK: %[[COMP:.*]] = arith.cmpi slt, %[[ITER]], %[[UB]]
// CHECK: cond_br %[[COMP]], ^[[BODY:.*]], ^[[CONTINUE:.*]] // CHECK: cf.cond_br %[[COMP]], ^[[BODY:.*]], ^[[CONTINUE:.*]]
// Bodies of scf.reduce operations are folded into the main loop body. The // Bodies of scf.reduce operations are folded into the main loop body. The
// result of this partial reduction is passed as argument to the condition // result of this partial reduction is passed as argument to the condition
@ -340,7 +340,7 @@ func @simple_parallel_reduce_loop(%arg0: index, %arg1: index,
// CHECK: %[[CST:.*]] = arith.constant 4.2 // CHECK: %[[CST:.*]] = arith.constant 4.2
// CHECK: %[[PROD:.*]] = arith.mulf %[[ITER_ARG]], %[[CST]] // CHECK: %[[PROD:.*]] = arith.mulf %[[ITER_ARG]], %[[CST]]
// CHECK: %[[INCR:.*]] = arith.addi %[[ITER]], %[[STEP]] // CHECK: %[[INCR:.*]] = arith.addi %[[ITER]], %[[STEP]]
// CHECK: br ^[[COND]](%[[INCR]], %[[PROD]] // CHECK: cf.br ^[[COND]](%[[INCR]], %[[PROD]]
// The continuation block has access to the (last value of) reduction. // The continuation block has access to the (last value of) reduction.
// CHECK: ^[[CONTINUE]]: // CHECK: ^[[CONTINUE]]:
@ -363,19 +363,19 @@ func @parallel_reduce_loop(%arg0 : index, %arg1 : index, %arg2 : index,
// Multiple reduction blocks should be folded in the same body, and the // Multiple reduction blocks should be folded in the same body, and the
// reduction value must be forwarded through block structures. // reduction value must be forwarded through block structures.
// CHECK: %[[INIT2:.*]] = arith.constant 42 // CHECK: %[[INIT2:.*]] = arith.constant 42
// CHECK: br ^[[COND_OUT:.*]](%{{.*}}, %[[INIT1]], %[[INIT2]] // CHECK: cf.br ^[[COND_OUT:.*]](%{{.*}}, %[[INIT1]], %[[INIT2]]
// CHECK: ^[[COND_OUT]](%{{.*}}: index, %[[ITER_ARG1_OUT:.*]]: f32, %[[ITER_ARG2_OUT:.*]]: i64 // CHECK: ^[[COND_OUT]](%{{.*}}: index, %[[ITER_ARG1_OUT:.*]]: f32, %[[ITER_ARG2_OUT:.*]]: i64
// CHECK: cond_br %{{.*}}, ^[[BODY_OUT:.*]], ^[[CONT_OUT:.*]] // CHECK: cf.cond_br %{{.*}}, ^[[BODY_OUT:.*]], ^[[CONT_OUT:.*]]
// CHECK: ^[[BODY_OUT]]: // CHECK: ^[[BODY_OUT]]:
// CHECK: br ^[[COND_IN:.*]](%{{.*}}, %[[ITER_ARG1_OUT]], %[[ITER_ARG2_OUT]] // CHECK: cf.br ^[[COND_IN:.*]](%{{.*}}, %[[ITER_ARG1_OUT]], %[[ITER_ARG2_OUT]]
// CHECK: ^[[COND_IN]](%{{.*}}: index, %[[ITER_ARG1_IN:.*]]: f32, %[[ITER_ARG2_IN:.*]]: i64 // CHECK: ^[[COND_IN]](%{{.*}}: index, %[[ITER_ARG1_IN:.*]]: f32, %[[ITER_ARG2_IN:.*]]: i64
// CHECK: cond_br %{{.*}}, ^[[BODY_IN:.*]], ^[[CONT_IN:.*]] // CHECK: cf.cond_br %{{.*}}, ^[[BODY_IN:.*]], ^[[CONT_IN:.*]]
// CHECK: ^[[BODY_IN]]: // CHECK: ^[[BODY_IN]]:
// CHECK: %[[REDUCE1:.*]] = arith.addf %[[ITER_ARG1_IN]], %{{.*}} // CHECK: %[[REDUCE1:.*]] = arith.addf %[[ITER_ARG1_IN]], %{{.*}}
// CHECK: %[[REDUCE2:.*]] = arith.ori %[[ITER_ARG2_IN]], %{{.*}} // CHECK: %[[REDUCE2:.*]] = arith.ori %[[ITER_ARG2_IN]], %{{.*}}
// CHECK: br ^[[COND_IN]](%{{.*}}, %[[REDUCE1]], %[[REDUCE2]] // CHECK: cf.br ^[[COND_IN]](%{{.*}}, %[[REDUCE1]], %[[REDUCE2]]
// CHECK: ^[[CONT_IN]]: // CHECK: ^[[CONT_IN]]:
// CHECK: br ^[[COND_OUT]](%{{.*}}, %[[ITER_ARG1_IN]], %[[ITER_ARG2_IN]] // CHECK: cf.br ^[[COND_OUT]](%{{.*}}, %[[ITER_ARG1_IN]], %[[ITER_ARG2_IN]]
// CHECK: ^[[CONT_OUT]]: // CHECK: ^[[CONT_OUT]]:
// CHECK: return %[[ITER_ARG1_OUT]], %[[ITER_ARG2_OUT]] // CHECK: return %[[ITER_ARG1_OUT]], %[[ITER_ARG2_OUT]]
%step = arith.constant 1 : index %step = arith.constant 1 : index
@ -416,17 +416,17 @@ func @unknown_op_inside_loop(%arg0: index, %arg1: index, %arg2: index) {
// CHECK-LABEL: @minimal_while // CHECK-LABEL: @minimal_while
func @minimal_while() { func @minimal_while() {
// CHECK: %[[COND:.*]] = "test.make_condition"() : () -> i1 // CHECK: %[[COND:.*]] = "test.make_condition"() : () -> i1
// CHECK: br ^[[BEFORE:.*]] // CHECK: cf.br ^[[BEFORE:.*]]
%0 = "test.make_condition"() : () -> i1 %0 = "test.make_condition"() : () -> i1
scf.while : () -> () { scf.while : () -> () {
// CHECK: ^[[BEFORE]]: // CHECK: ^[[BEFORE]]:
// CHECK: cond_br %[[COND]], ^[[AFTER:.*]], ^[[CONT:.*]] // CHECK: cf.cond_br %[[COND]], ^[[AFTER:.*]], ^[[CONT:.*]]
scf.condition(%0) scf.condition(%0)
} do { } do {
// CHECK: ^[[AFTER]]: // CHECK: ^[[AFTER]]:
// CHECK: "test.some_payload"() : () -> () // CHECK: "test.some_payload"() : () -> ()
"test.some_payload"() : () -> () "test.some_payload"() : () -> ()
// CHECK: br ^[[BEFORE]] // CHECK: cf.br ^[[BEFORE]]
scf.yield scf.yield
} }
// CHECK: ^[[CONT]]: // CHECK: ^[[CONT]]:
@ -436,16 +436,16 @@ func @minimal_while() {
// CHECK-LABEL: @do_while // CHECK-LABEL: @do_while
func @do_while(%arg0: f32) { func @do_while(%arg0: f32) {
// CHECK: br ^[[BEFORE:.*]]({{.*}}: f32) // CHECK: cf.br ^[[BEFORE:.*]]({{.*}}: f32)
scf.while (%arg1 = %arg0) : (f32) -> (f32) { scf.while (%arg1 = %arg0) : (f32) -> (f32) {
// CHECK: ^[[BEFORE]](%[[VAL:.*]]: f32): // CHECK: ^[[BEFORE]](%[[VAL:.*]]: f32):
// CHECK: %[[COND:.*]] = "test.make_condition"() : () -> i1 // CHECK: %[[COND:.*]] = "test.make_condition"() : () -> i1
%0 = "test.make_condition"() : () -> i1 %0 = "test.make_condition"() : () -> i1
// CHECK: cond_br %[[COND]], ^[[BEFORE]](%[[VAL]] : f32), ^[[CONT:.*]] // CHECK: cf.cond_br %[[COND]], ^[[BEFORE]](%[[VAL]] : f32), ^[[CONT:.*]]
scf.condition(%0) %arg1 : f32 scf.condition(%0) %arg1 : f32
} do { } do {
^bb0(%arg2: f32): ^bb0(%arg2: f32):
// CHECK-NOT: br ^[[BEFORE]] // CHECK-NOT: cf.br ^[[BEFORE]]
scf.yield %arg2 : f32 scf.yield %arg2 : f32
} }
// CHECK: ^[[CONT]]: // CHECK: ^[[CONT]]:
@ -460,21 +460,21 @@ func @while_values(%arg0: i32, %arg1: f32) {
%0 = "test.make_condition"() : () -> i1 %0 = "test.make_condition"() : () -> i1
%c0_i32 = arith.constant 0 : i32 %c0_i32 = arith.constant 0 : i32
%cst = arith.constant 0.000000e+00 : f32 %cst = arith.constant 0.000000e+00 : f32
// CHECK: br ^[[BEFORE:.*]](%[[ARG0]], %[[ARG1]] : i32, f32) // CHECK: cf.br ^[[BEFORE:.*]](%[[ARG0]], %[[ARG1]] : i32, f32)
%1:2 = scf.while (%arg2 = %arg0, %arg3 = %arg1) : (i32, f32) -> (i64, f64) { %1:2 = scf.while (%arg2 = %arg0, %arg3 = %arg1) : (i32, f32) -> (i64, f64) {
// CHECK: ^bb1(%[[ARG2:.*]]: i32, %[[ARG3:.]]: f32): // CHECK: ^bb1(%[[ARG2:.*]]: i32, %[[ARG3:.]]: f32):
// CHECK: %[[VAL1:.*]] = arith.extui %[[ARG0]] : i32 to i64 // CHECK: %[[VAL1:.*]] = arith.extui %[[ARG0]] : i32 to i64
%2 = arith.extui %arg0 : i32 to i64 %2 = arith.extui %arg0 : i32 to i64
// CHECK: %[[VAL2:.*]] = arith.extf %[[ARG3]] : f32 to f64 // CHECK: %[[VAL2:.*]] = arith.extf %[[ARG3]] : f32 to f64
%3 = arith.extf %arg3 : f32 to f64 %3 = arith.extf %arg3 : f32 to f64
// CHECK: cond_br %[[COND]], // CHECK: cf.cond_br %[[COND]],
// CHECK: ^[[AFTER:.*]](%[[VAL1]], %[[VAL2]] : i64, f64), // CHECK: ^[[AFTER:.*]](%[[VAL1]], %[[VAL2]] : i64, f64),
// CHECK: ^[[CONT:.*]] // CHECK: ^[[CONT:.*]]
scf.condition(%0) %2, %3 : i64, f64 scf.condition(%0) %2, %3 : i64, f64
} do { } do {
// CHECK: ^[[AFTER]](%[[ARG4:.*]]: i64, %[[ARG5:.*]]: f64): // CHECK: ^[[AFTER]](%[[ARG4:.*]]: i64, %[[ARG5:.*]]: f64):
^bb0(%arg2: i64, %arg3: f64): ^bb0(%arg2: i64, %arg3: f64):
// CHECK: br ^[[BEFORE]](%{{.*}}, %{{.*}} : i32, f32) // CHECK: cf.br ^[[BEFORE]](%{{.*}}, %{{.*}} : i32, f32)
scf.yield %c0_i32, %cst : i32, f32 scf.yield %c0_i32, %cst : i32, f32
} }
// CHECK: ^bb3: // CHECK: ^bb3:
@ -484,17 +484,17 @@ func @while_values(%arg0: i32, %arg1: f32) {
// CHECK-LABEL: @nested_while_ops // CHECK-LABEL: @nested_while_ops
func @nested_while_ops(%arg0: f32) -> i64 { func @nested_while_ops(%arg0: f32) -> i64 {
// CHECK: br ^[[OUTER_BEFORE:.*]](%{{.*}} : f32) // CHECK: cf.br ^[[OUTER_BEFORE:.*]](%{{.*}} : f32)
%0 = scf.while(%outer = %arg0) : (f32) -> i64 { %0 = scf.while(%outer = %arg0) : (f32) -> i64 {
// CHECK: ^[[OUTER_BEFORE]](%{{.*}}: f32): // CHECK: ^[[OUTER_BEFORE]](%{{.*}}: f32):
// CHECK: %[[OUTER_COND:.*]] = "test.outer_before_pre"() : () -> i1 // CHECK: %[[OUTER_COND:.*]] = "test.outer_before_pre"() : () -> i1
%cond = "test.outer_before_pre"() : () -> i1 %cond = "test.outer_before_pre"() : () -> i1
// CHECK: br ^[[INNER_BEFORE_BEFORE:.*]](%{{.*}} : f32) // CHECK: cf.br ^[[INNER_BEFORE_BEFORE:.*]](%{{.*}} : f32)
%1 = scf.while(%inner = %outer) : (f32) -> i64 { %1 = scf.while(%inner = %outer) : (f32) -> i64 {
// CHECK: ^[[INNER_BEFORE_BEFORE]](%{{.*}}: f32): // CHECK: ^[[INNER_BEFORE_BEFORE]](%{{.*}}: f32):
// CHECK: %[[INNER1:.*]]:2 = "test.inner_before"(%{{.*}}) : (f32) -> (i1, i64) // CHECK: %[[INNER1:.*]]:2 = "test.inner_before"(%{{.*}}) : (f32) -> (i1, i64)
%2:2 = "test.inner_before"(%inner) : (f32) -> (i1, i64) %2:2 = "test.inner_before"(%inner) : (f32) -> (i1, i64)
// CHECK: cond_br %[[INNER1]]#0, // CHECK: cf.cond_br %[[INNER1]]#0,
// CHECK: ^[[INNER_BEFORE_AFTER:.*]](%[[INNER1]]#1 : i64), // CHECK: ^[[INNER_BEFORE_AFTER:.*]](%[[INNER1]]#1 : i64),
// CHECK: ^[[OUTER_BEFORE_LAST:.*]] // CHECK: ^[[OUTER_BEFORE_LAST:.*]]
scf.condition(%2#0) %2#1 : i64 scf.condition(%2#0) %2#1 : i64
@ -503,13 +503,13 @@ func @nested_while_ops(%arg0: f32) -> i64 {
^bb0(%arg1: i64): ^bb0(%arg1: i64):
// CHECK: %[[INNER2:.*]] = "test.inner_after"(%{{.*}}) : (i64) -> f32 // CHECK: %[[INNER2:.*]] = "test.inner_after"(%{{.*}}) : (i64) -> f32
%3 = "test.inner_after"(%arg1) : (i64) -> f32 %3 = "test.inner_after"(%arg1) : (i64) -> f32
// CHECK: br ^[[INNER_BEFORE_BEFORE]](%[[INNER2]] : f32) // CHECK: cf.br ^[[INNER_BEFORE_BEFORE]](%[[INNER2]] : f32)
scf.yield %3 : f32 scf.yield %3 : f32
} }
// CHECK: ^[[OUTER_BEFORE_LAST]]: // CHECK: ^[[OUTER_BEFORE_LAST]]:
// CHECK: "test.outer_before_post"() : () -> () // CHECK: "test.outer_before_post"() : () -> ()
"test.outer_before_post"() : () -> () "test.outer_before_post"() : () -> ()
// CHECK: cond_br %[[OUTER_COND]], // CHECK: cf.cond_br %[[OUTER_COND]],
// CHECK: ^[[OUTER_AFTER:.*]](%[[INNER1]]#1 : i64), // CHECK: ^[[OUTER_AFTER:.*]](%[[INNER1]]#1 : i64),
// CHECK: ^[[CONTINUATION:.*]] // CHECK: ^[[CONTINUATION:.*]]
scf.condition(%cond) %1 : i64 scf.condition(%cond) %1 : i64
@ -518,12 +518,12 @@ func @nested_while_ops(%arg0: f32) -> i64 {
^bb2(%arg2: i64): ^bb2(%arg2: i64):
// CHECK: "test.outer_after_pre"(%{{.*}}) : (i64) -> () // CHECK: "test.outer_after_pre"(%{{.*}}) : (i64) -> ()
"test.outer_after_pre"(%arg2) : (i64) -> () "test.outer_after_pre"(%arg2) : (i64) -> ()
// CHECK: br ^[[INNER_AFTER_BEFORE:.*]](%{{.*}} : i64) // CHECK: cf.br ^[[INNER_AFTER_BEFORE:.*]](%{{.*}} : i64)
%4 = scf.while(%inner = %arg2) : (i64) -> f32 { %4 = scf.while(%inner = %arg2) : (i64) -> f32 {
// CHECK: ^[[INNER_AFTER_BEFORE]](%{{.*}}: i64): // CHECK: ^[[INNER_AFTER_BEFORE]](%{{.*}}: i64):
// CHECK: %[[INNER3:.*]]:2 = "test.inner2_before"(%{{.*}}) : (i64) -> (i1, f32) // CHECK: %[[INNER3:.*]]:2 = "test.inner2_before"(%{{.*}}) : (i64) -> (i1, f32)
%5:2 = "test.inner2_before"(%inner) : (i64) -> (i1, f32) %5:2 = "test.inner2_before"(%inner) : (i64) -> (i1, f32)
// CHECK: cond_br %[[INNER3]]#0, // CHECK: cf.cond_br %[[INNER3]]#0,
// CHECK: ^[[INNER_AFTER_AFTER:.*]](%[[INNER3]]#1 : f32), // CHECK: ^[[INNER_AFTER_AFTER:.*]](%[[INNER3]]#1 : f32),
// CHECK: ^[[OUTER_AFTER_LAST:.*]] // CHECK: ^[[OUTER_AFTER_LAST:.*]]
scf.condition(%5#0) %5#1 : f32 scf.condition(%5#0) %5#1 : f32
@ -532,13 +532,13 @@ func @nested_while_ops(%arg0: f32) -> i64 {
^bb3(%arg3: f32): ^bb3(%arg3: f32):
// CHECK: %{{.*}} = "test.inner2_after"(%{{.*}}) : (f32) -> i64 // CHECK: %{{.*}} = "test.inner2_after"(%{{.*}}) : (f32) -> i64
%6 = "test.inner2_after"(%arg3) : (f32) -> i64 %6 = "test.inner2_after"(%arg3) : (f32) -> i64
// CHECK: br ^[[INNER_AFTER_BEFORE]](%{{.*}} : i64) // CHECK: cf.br ^[[INNER_AFTER_BEFORE]](%{{.*}} : i64)
scf.yield %6 : i64 scf.yield %6 : i64
} }
// CHECK: ^[[OUTER_AFTER_LAST]]: // CHECK: ^[[OUTER_AFTER_LAST]]:
// CHECK: "test.outer_after_post"() : () -> () // CHECK: "test.outer_after_post"() : () -> ()
"test.outer_after_post"() : () -> () "test.outer_after_post"() : () -> ()
// CHECK: br ^[[OUTER_BEFORE]](%[[INNER3]]#1 : f32) // CHECK: cf.br ^[[OUTER_BEFORE]](%[[INNER3]]#1 : f32)
scf.yield %4 : f32 scf.yield %4 : f32
} }
// CHECK: ^[[CONTINUATION]]: // CHECK: ^[[CONTINUATION]]:
@ -549,27 +549,27 @@ func @nested_while_ops(%arg0: f32) -> i64 {
// CHECK-LABEL: @ifs_in_parallel // CHECK-LABEL: @ifs_in_parallel
// CHECK: (%[[ARG0:.*]]: index, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index, %[[ARG3:.*]]: i1, %[[ARG4:.*]]: i1) // CHECK: (%[[ARG0:.*]]: index, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index, %[[ARG3:.*]]: i1, %[[ARG4:.*]]: i1)
func @ifs_in_parallel(%arg1: index, %arg2: index, %arg3: index, %arg4: i1, %arg5: i1) { func @ifs_in_parallel(%arg1: index, %arg2: index, %arg3: index, %arg4: i1, %arg5: i1) {
// CHECK: br ^[[LOOP_LATCH:.*]](%[[ARG0]] : index) // CHECK: cf.br ^[[LOOP_LATCH:.*]](%[[ARG0]] : index)
// CHECK: ^[[LOOP_LATCH]](%[[LOOP_IV:.*]]: index): // CHECK: ^[[LOOP_LATCH]](%[[LOOP_IV:.*]]: index):
// CHECK: %[[LOOP_COND:.*]] = arith.cmpi slt, %[[LOOP_IV]], %[[ARG1]] : index // CHECK: %[[LOOP_COND:.*]] = arith.cmpi slt, %[[LOOP_IV]], %[[ARG1]] : index
// CHECK: cond_br %[[LOOP_COND]], ^[[LOOP_BODY:.*]], ^[[LOOP_CONT:.*]] // CHECK: cf.cond_br %[[LOOP_COND]], ^[[LOOP_BODY:.*]], ^[[LOOP_CONT:.*]]
// CHECK: ^[[LOOP_BODY]]: // CHECK: ^[[LOOP_BODY]]:
// CHECK: cond_br %[[ARG3]], ^[[IF1_THEN:.*]], ^[[IF1_CONT:.*]] // CHECK: cf.cond_br %[[ARG3]], ^[[IF1_THEN:.*]], ^[[IF1_CONT:.*]]
// CHECK: ^[[IF1_THEN]]: // CHECK: ^[[IF1_THEN]]:
// CHECK: cond_br %[[ARG4]], ^[[IF2_THEN:.*]], ^[[IF2_ELSE:.*]] // CHECK: cf.cond_br %[[ARG4]], ^[[IF2_THEN:.*]], ^[[IF2_ELSE:.*]]
// CHECK: ^[[IF2_THEN]]: // CHECK: ^[[IF2_THEN]]:
// CHECK: %{{.*}} = "test.if2"() : () -> index // CHECK: %{{.*}} = "test.if2"() : () -> index
// CHECK: br ^[[IF2_MERGE:.*]](%{{.*}} : index) // CHECK: cf.br ^[[IF2_MERGE:.*]](%{{.*}} : index)
// CHECK: ^[[IF2_ELSE]]: // CHECK: ^[[IF2_ELSE]]:
// CHECK: %{{.*}} = "test.else2"() : () -> index // CHECK: %{{.*}} = "test.else2"() : () -> index
// CHECK: br ^[[IF2_MERGE]](%{{.*}} : index) // CHECK: cf.br ^[[IF2_MERGE]](%{{.*}} : index)
// CHECK: ^[[IF2_MERGE]](%{{.*}}: index): // CHECK: ^[[IF2_MERGE]](%{{.*}}: index):
// CHECK: br ^[[IF2_CONT:.*]] // CHECK: cf.br ^[[IF2_CONT:.*]]
// CHECK: ^[[IF2_CONT]]: // CHECK: ^[[IF2_CONT]]:
// CHECK: br ^[[IF1_CONT]] // CHECK: cf.br ^[[IF1_CONT]]
// CHECK: ^[[IF1_CONT]]: // CHECK: ^[[IF1_CONT]]:
// CHECK: %{{.*}} = arith.addi %[[LOOP_IV]], %[[ARG2]] : index // CHECK: %{{.*}} = arith.addi %[[LOOP_IV]], %[[ARG2]] : index
// CHECK: br ^[[LOOP_LATCH]](%{{.*}} : index) // CHECK: cf.br ^[[LOOP_LATCH]](%{{.*}} : index)
scf.parallel (%i) = (%arg1) to (%arg2) step (%arg3) { scf.parallel (%i) = (%arg1) to (%arg2) step (%arg3) {
scf.if %arg4 { scf.if %arg4 {
%0 = scf.if %arg5 -> (index) { %0 = scf.if %arg5 -> (index) {
@ -593,7 +593,7 @@ func @func_execute_region_elim_multi_yield() {
"test.foo"() : () -> () "test.foo"() : () -> ()
%v = scf.execute_region -> i64 { %v = scf.execute_region -> i64 {
%c = "test.cmp"() : () -> i1 %c = "test.cmp"() : () -> i1
cond_br %c, ^bb2, ^bb3 cf.cond_br %c, ^bb2, ^bb3
^bb2: ^bb2:
%x = "test.val1"() : () -> i64 %x = "test.val1"() : () -> i64
scf.yield %x : i64 scf.yield %x : i64
@ -607,16 +607,16 @@ func @func_execute_region_elim_multi_yield() {
// CHECK-NOT: execute_region // CHECK-NOT: execute_region
// CHECK: "test.foo" // CHECK: "test.foo"
// CHECK: br ^[[rentry:.+]] // CHECK: cf.br ^[[rentry:.+]]
// CHECK: ^[[rentry]] // CHECK: ^[[rentry]]
// CHECK: %[[cmp:.+]] = "test.cmp" // CHECK: %[[cmp:.+]] = "test.cmp"
// CHECK: cond_br %[[cmp]], ^[[bb1:.+]], ^[[bb2:.+]] // CHECK: cf.cond_br %[[cmp]], ^[[bb1:.+]], ^[[bb2:.+]]
// CHECK: ^[[bb1]]: // CHECK: ^[[bb1]]:
// CHECK: %[[x:.+]] = "test.val1" // CHECK: %[[x:.+]] = "test.val1"
// CHECK: br ^[[bb3:.+]](%[[x]] : i64) // CHECK: cf.br ^[[bb3:.+]](%[[x]] : i64)
// CHECK: ^[[bb2]]: // CHECK: ^[[bb2]]:
// CHECK: %[[y:.+]] = "test.val2" // CHECK: %[[y:.+]] = "test.val2"
// CHECK: br ^[[bb3]](%[[y:.+]] : i64) // CHECK: cf.br ^[[bb3]](%[[y:.+]] : i64)
// CHECK: ^[[bb3]](%[[z:.+]]: i64): // CHECK: ^[[bb3]](%[[z:.+]]: i64):
// CHECK: "test.bar"(%[[z]]) // CHECK: "test.bar"(%[[z]])
// CHECK: return // CHECK: return

View File

@ -6,7 +6,7 @@
// CHECK-SAME: %[[RHS:.*]]: tensor<?xindex>) -> !shape.witness { // CHECK-SAME: %[[RHS:.*]]: tensor<?xindex>) -> !shape.witness {
// CHECK: %[[RET:.*]] = shape.const_witness true // CHECK: %[[RET:.*]] = shape.const_witness true
// CHECK: %[[BROADCAST_IS_VALID:.*]] = shape.is_broadcastable %[[LHS]], %[[RHS]] // CHECK: %[[BROADCAST_IS_VALID:.*]] = shape.is_broadcastable %[[LHS]], %[[RHS]]
// CHECK: assert %[[BROADCAST_IS_VALID]], "required broadcastable shapes" // CHECK: cf.assert %[[BROADCAST_IS_VALID]], "required broadcastable shapes"
// CHECK: return %[[RET]] : !shape.witness // CHECK: return %[[RET]] : !shape.witness
// CHECK: } // CHECK: }
func @cstr_broadcastable(%arg0: tensor<?xindex>, %arg1: tensor<?xindex>) -> !shape.witness { func @cstr_broadcastable(%arg0: tensor<?xindex>, %arg1: tensor<?xindex>) -> !shape.witness {
@ -19,7 +19,7 @@ func @cstr_broadcastable(%arg0: tensor<?xindex>, %arg1: tensor<?xindex>) -> !sha
// CHECK-SAME: %[[RHS:.*]]: tensor<?xindex>) -> !shape.witness { // CHECK-SAME: %[[RHS:.*]]: tensor<?xindex>) -> !shape.witness {
// CHECK: %[[RET:.*]] = shape.const_witness true // CHECK: %[[RET:.*]] = shape.const_witness true
// CHECK: %[[EQUAL_IS_VALID:.*]] = shape.shape_eq %[[LHS]], %[[RHS]] // CHECK: %[[EQUAL_IS_VALID:.*]] = shape.shape_eq %[[LHS]], %[[RHS]]
// CHECK: assert %[[EQUAL_IS_VALID]], "required equal shapes" // CHECK: cf.assert %[[EQUAL_IS_VALID]], "required equal shapes"
// CHECK: return %[[RET]] : !shape.witness // CHECK: return %[[RET]] : !shape.witness
// CHECK: } // CHECK: }
func @cstr_eq(%arg0: tensor<?xindex>, %arg1: tensor<?xindex>) -> !shape.witness { func @cstr_eq(%arg0: tensor<?xindex>, %arg1: tensor<?xindex>) -> !shape.witness {
@ -30,7 +30,7 @@ func @cstr_eq(%arg0: tensor<?xindex>, %arg1: tensor<?xindex>) -> !shape.witness
// CHECK-LABEL: func @cstr_require // CHECK-LABEL: func @cstr_require
func @cstr_require(%arg0: i1) -> !shape.witness { func @cstr_require(%arg0: i1) -> !shape.witness {
// CHECK: %[[RET:.*]] = shape.const_witness true // CHECK: %[[RET:.*]] = shape.const_witness true
// CHECK: assert %arg0, "msg" // CHECK: cf.assert %arg0, "msg"
// CHECK: return %[[RET]] // CHECK: return %[[RET]]
%witness = shape.cstr_require %arg0, "msg" %witness = shape.cstr_require %arg0, "msg"
return %witness : !shape.witness return %witness : !shape.witness

View File

@ -29,7 +29,7 @@ func private @memref_call_conv_nested(%arg0: (memref<?xf32>) -> ())
//CHECK-LABEL: llvm.func @pass_through(%arg0: !llvm.ptr<func<void ()>>) -> !llvm.ptr<func<void ()>> { //CHECK-LABEL: llvm.func @pass_through(%arg0: !llvm.ptr<func<void ()>>) -> !llvm.ptr<func<void ()>> {
func @pass_through(%arg0: () -> ()) -> (() -> ()) { func @pass_through(%arg0: () -> ()) -> (() -> ()) {
// CHECK-NEXT: llvm.br ^bb1(%arg0 : !llvm.ptr<func<void ()>>) // CHECK-NEXT: llvm.br ^bb1(%arg0 : !llvm.ptr<func<void ()>>)
br ^bb1(%arg0 : () -> ()) cf.br ^bb1(%arg0 : () -> ())
//CHECK-NEXT: ^bb1(%0: !llvm.ptr<func<void ()>>): //CHECK-NEXT: ^bb1(%0: !llvm.ptr<func<void ()>>):
^bb1(%bbarg: () -> ()): ^bb1(%bbarg: () -> ()):

View File

@ -109,17 +109,17 @@ func @loop_carried(%arg0 : index, %arg1 : index, %arg2 : index, %base0 : !base_t
// This test checks that in the BAREPTR case, the branch arguments only forward the descriptor. // This test checks that in the BAREPTR case, the branch arguments only forward the descriptor.
// This test was lowered from a simple scf.for that swaps 2 memref iter_args. // This test was lowered from a simple scf.for that swaps 2 memref iter_args.
// BAREPTR: llvm.br ^bb1(%{{.*}}, %{{.*}}, %{{.*}} : i64, !llvm.struct<(ptr<i32, 201>, ptr<i32, 201>, i64, array<1 x i64>, array<1 x i64>)>, !llvm.struct<(ptr<i32, 201>, ptr<i32, 201>, i64, array<1 x i64>, array<1 x i64>)>) // BAREPTR: llvm.br ^bb1(%{{.*}}, %{{.*}}, %{{.*}} : i64, !llvm.struct<(ptr<i32, 201>, ptr<i32, 201>, i64, array<1 x i64>, array<1 x i64>)>, !llvm.struct<(ptr<i32, 201>, ptr<i32, 201>, i64, array<1 x i64>, array<1 x i64>)>)
br ^bb1(%arg0, %base0, %base1 : index, memref<64xi32, 201>, memref<64xi32, 201>) cf.br ^bb1(%arg0, %base0, %base1 : index, memref<64xi32, 201>, memref<64xi32, 201>)
// BAREPTR-NEXT: ^bb1 // BAREPTR-NEXT: ^bb1
// BAREPTR-NEXT: llvm.icmp // BAREPTR-NEXT: llvm.icmp
// BAREPTR-NEXT: llvm.cond_br %{{.*}}, ^bb2, ^bb3 // BAREPTR-NEXT: llvm.cond_br %{{.*}}, ^bb2, ^bb3
^bb1(%0: index, %1: memref<64xi32, 201>, %2: memref<64xi32, 201>): // 2 preds: ^bb0, ^bb2 ^bb1(%0: index, %1: memref<64xi32, 201>, %2: memref<64xi32, 201>): // 2 preds: ^bb0, ^bb2
%3 = arith.cmpi slt, %0, %arg1 : index %3 = arith.cmpi slt, %0, %arg1 : index
cond_br %3, ^bb2, ^bb3 cf.cond_br %3, ^bb2, ^bb3
^bb2: // pred: ^bb1 ^bb2: // pred: ^bb1
%4 = arith.addi %0, %arg2 : index %4 = arith.addi %0, %arg2 : index
br ^bb1(%4, %2, %1 : index, memref<64xi32, 201>, memref<64xi32, 201>) cf.br ^bb1(%4, %2, %1 : index, memref<64xi32, 201>, memref<64xi32, 201>)
^bb3: // pred: ^bb1 ^bb3: // pred: ^bb1
return %1, %2 : memref<64xi32, 201>, memref<64xi32, 201> return %1, %2 : memref<64xi32, 201>, memref<64xi32, 201>
} }

View File

@ -18,7 +18,7 @@ func @simple_loop() {
^bb0: ^bb0:
// CHECK-NEXT: llvm.br ^bb1 // CHECK-NEXT: llvm.br ^bb1
// CHECK32-NEXT: llvm.br ^bb1 // CHECK32-NEXT: llvm.br ^bb1
br ^bb1 cf.br ^bb1
// CHECK-NEXT: ^bb1: // pred: ^bb0 // CHECK-NEXT: ^bb1: // pred: ^bb0
// CHECK-NEXT: {{.*}} = llvm.mlir.constant(1 : index) : i64 // CHECK-NEXT: {{.*}} = llvm.mlir.constant(1 : index) : i64
@ -31,7 +31,7 @@ func @simple_loop() {
^bb1: // pred: ^bb0 ^bb1: // pred: ^bb0
%c1 = arith.constant 1 : index %c1 = arith.constant 1 : index
%c42 = arith.constant 42 : index %c42 = arith.constant 42 : index
br ^bb2(%c1 : index) cf.br ^bb2(%c1 : index)
// CHECK: ^bb2({{.*}}: i64): // 2 preds: ^bb1, ^bb3 // CHECK: ^bb2({{.*}}: i64): // 2 preds: ^bb1, ^bb3
// CHECK-NEXT: {{.*}} = llvm.icmp "slt" {{.*}}, {{.*}} : i64 // CHECK-NEXT: {{.*}} = llvm.icmp "slt" {{.*}}, {{.*}} : i64
@ -41,7 +41,7 @@ func @simple_loop() {
// CHECK32-NEXT: llvm.cond_br {{.*}}, ^bb3, ^bb4 // CHECK32-NEXT: llvm.cond_br {{.*}}, ^bb3, ^bb4
^bb2(%0: index): // 2 preds: ^bb1, ^bb3 ^bb2(%0: index): // 2 preds: ^bb1, ^bb3
%1 = arith.cmpi slt, %0, %c42 : index %1 = arith.cmpi slt, %0, %c42 : index
cond_br %1, ^bb3, ^bb4 cf.cond_br %1, ^bb3, ^bb4
// CHECK: ^bb3: // pred: ^bb2 // CHECK: ^bb3: // pred: ^bb2
// CHECK-NEXT: llvm.call @body({{.*}}) : (i64) -> () // CHECK-NEXT: llvm.call @body({{.*}}) : (i64) -> ()
@ -57,7 +57,7 @@ func @simple_loop() {
call @body(%0) : (index) -> () call @body(%0) : (index) -> ()
%c1_0 = arith.constant 1 : index %c1_0 = arith.constant 1 : index
%2 = arith.addi %0, %c1_0 : index %2 = arith.addi %0, %c1_0 : index
br ^bb2(%2 : index) cf.br ^bb2(%2 : index)
// CHECK: ^bb4: // pred: ^bb2 // CHECK: ^bb4: // pred: ^bb2
// CHECK-NEXT: llvm.return // CHECK-NEXT: llvm.return
@ -111,7 +111,7 @@ func private @other(index, i32) -> i32
func @func_args(i32, i32) -> i32 { func @func_args(i32, i32) -> i32 {
^bb0(%arg0: i32, %arg1: i32): ^bb0(%arg0: i32, %arg1: i32):
%c0_i32 = arith.constant 0 : i32 %c0_i32 = arith.constant 0 : i32
br ^bb1 cf.br ^bb1
// CHECK-NEXT: ^bb1: // pred: ^bb0 // CHECK-NEXT: ^bb1: // pred: ^bb0
// CHECK-NEXT: {{.*}} = llvm.mlir.constant(0 : index) : i64 // CHECK-NEXT: {{.*}} = llvm.mlir.constant(0 : index) : i64
@ -124,7 +124,7 @@ func @func_args(i32, i32) -> i32 {
^bb1: // pred: ^bb0 ^bb1: // pred: ^bb0
%c0 = arith.constant 0 : index %c0 = arith.constant 0 : index
%c42 = arith.constant 42 : index %c42 = arith.constant 42 : index
br ^bb2(%c0 : index) cf.br ^bb2(%c0 : index)
// CHECK-NEXT: ^bb2({{.*}}: i64): // 2 preds: ^bb1, ^bb3 // CHECK-NEXT: ^bb2({{.*}}: i64): // 2 preds: ^bb1, ^bb3
// CHECK-NEXT: {{.*}} = llvm.icmp "slt" {{.*}}, {{.*}} : i64 // CHECK-NEXT: {{.*}} = llvm.icmp "slt" {{.*}}, {{.*}} : i64
@ -134,7 +134,7 @@ func @func_args(i32, i32) -> i32 {
// CHECK32-NEXT: llvm.cond_br {{.*}}, ^bb3, ^bb4 // CHECK32-NEXT: llvm.cond_br {{.*}}, ^bb3, ^bb4
^bb2(%0: index): // 2 preds: ^bb1, ^bb3 ^bb2(%0: index): // 2 preds: ^bb1, ^bb3
%1 = arith.cmpi slt, %0, %c42 : index %1 = arith.cmpi slt, %0, %c42 : index
cond_br %1, ^bb3, ^bb4 cf.cond_br %1, ^bb3, ^bb4
// CHECK-NEXT: ^bb3: // pred: ^bb2 // CHECK-NEXT: ^bb3: // pred: ^bb2
// CHECK-NEXT: {{.*}} = llvm.call @body_args({{.*}}) : (i64) -> i64 // CHECK-NEXT: {{.*}} = llvm.call @body_args({{.*}}) : (i64) -> i64
@ -159,7 +159,7 @@ func @func_args(i32, i32) -> i32 {
%5 = call @other(%2, %arg1) : (index, i32) -> i32 %5 = call @other(%2, %arg1) : (index, i32) -> i32
%c1 = arith.constant 1 : index %c1 = arith.constant 1 : index
%6 = arith.addi %0, %c1 : index %6 = arith.addi %0, %c1 : index
br ^bb2(%6 : index) cf.br ^bb2(%6 : index)
// CHECK-NEXT: ^bb4: // pred: ^bb2 // CHECK-NEXT: ^bb4: // pred: ^bb2
// CHECK-NEXT: {{.*}} = llvm.mlir.constant(0 : index) : i64 // CHECK-NEXT: {{.*}} = llvm.mlir.constant(0 : index) : i64
@ -191,7 +191,7 @@ func private @post(index)
// CHECK-NEXT: llvm.br ^bb1 // CHECK-NEXT: llvm.br ^bb1
func @imperfectly_nested_loops() { func @imperfectly_nested_loops() {
^bb0: ^bb0:
br ^bb1 cf.br ^bb1
// CHECK-NEXT: ^bb1: // pred: ^bb0 // CHECK-NEXT: ^bb1: // pred: ^bb0
// CHECK-NEXT: {{.*}} = llvm.mlir.constant(0 : index) : i64 // CHECK-NEXT: {{.*}} = llvm.mlir.constant(0 : index) : i64
@ -200,21 +200,21 @@ func @imperfectly_nested_loops() {
^bb1: // pred: ^bb0 ^bb1: // pred: ^bb0
%c0 = arith.constant 0 : index %c0 = arith.constant 0 : index
%c42 = arith.constant 42 : index %c42 = arith.constant 42 : index
br ^bb2(%c0 : index) cf.br ^bb2(%c0 : index)
// CHECK-NEXT: ^bb2({{.*}}: i64): // 2 preds: ^bb1, ^bb7 // CHECK-NEXT: ^bb2({{.*}}: i64): // 2 preds: ^bb1, ^bb7
// CHECK-NEXT: {{.*}} = llvm.icmp "slt" {{.*}}, {{.*}} : i64 // CHECK-NEXT: {{.*}} = llvm.icmp "slt" {{.*}}, {{.*}} : i64
// CHECK-NEXT: llvm.cond_br {{.*}}, ^bb3, ^bb8 // CHECK-NEXT: llvm.cond_br {{.*}}, ^bb3, ^bb8
^bb2(%0: index): // 2 preds: ^bb1, ^bb7 ^bb2(%0: index): // 2 preds: ^bb1, ^bb7
%1 = arith.cmpi slt, %0, %c42 : index %1 = arith.cmpi slt, %0, %c42 : index
cond_br %1, ^bb3, ^bb8 cf.cond_br %1, ^bb3, ^bb8
// CHECK-NEXT: ^bb3: // CHECK-NEXT: ^bb3:
// CHECK-NEXT: llvm.call @pre({{.*}}) : (i64) -> () // CHECK-NEXT: llvm.call @pre({{.*}}) : (i64) -> ()
// CHECK-NEXT: llvm.br ^bb4 // CHECK-NEXT: llvm.br ^bb4
^bb3: // pred: ^bb2 ^bb3: // pred: ^bb2
call @pre(%0) : (index) -> () call @pre(%0) : (index) -> ()
br ^bb4 cf.br ^bb4
// CHECK-NEXT: ^bb4: // pred: ^bb3 // CHECK-NEXT: ^bb4: // pred: ^bb3
// CHECK-NEXT: {{.*}} = llvm.mlir.constant(7 : index) : i64 // CHECK-NEXT: {{.*}} = llvm.mlir.constant(7 : index) : i64
@ -223,14 +223,14 @@ func @imperfectly_nested_loops() {
^bb4: // pred: ^bb3 ^bb4: // pred: ^bb3
%c7 = arith.constant 7 : index %c7 = arith.constant 7 : index
%c56 = arith.constant 56 : index %c56 = arith.constant 56 : index
br ^bb5(%c7 : index) cf.br ^bb5(%c7 : index)
// CHECK-NEXT: ^bb5({{.*}}: i64): // 2 preds: ^bb4, ^bb6 // CHECK-NEXT: ^bb5({{.*}}: i64): // 2 preds: ^bb4, ^bb6
// CHECK-NEXT: {{.*}} = llvm.icmp "slt" {{.*}}, {{.*}} : i64 // CHECK-NEXT: {{.*}} = llvm.icmp "slt" {{.*}}, {{.*}} : i64
// CHECK-NEXT: llvm.cond_br {{.*}}, ^bb6, ^bb7 // CHECK-NEXT: llvm.cond_br {{.*}}, ^bb6, ^bb7
^bb5(%2: index): // 2 preds: ^bb4, ^bb6 ^bb5(%2: index): // 2 preds: ^bb4, ^bb6
%3 = arith.cmpi slt, %2, %c56 : index %3 = arith.cmpi slt, %2, %c56 : index
cond_br %3, ^bb6, ^bb7 cf.cond_br %3, ^bb6, ^bb7
// CHECK-NEXT: ^bb6: // pred: ^bb5 // CHECK-NEXT: ^bb6: // pred: ^bb5
// CHECK-NEXT: llvm.call @body2({{.*}}, {{.*}}) : (i64, i64) -> () // CHECK-NEXT: llvm.call @body2({{.*}}, {{.*}}) : (i64, i64) -> ()
@ -241,7 +241,7 @@ func @imperfectly_nested_loops() {
call @body2(%0, %2) : (index, index) -> () call @body2(%0, %2) : (index, index) -> ()
%c2 = arith.constant 2 : index %c2 = arith.constant 2 : index
%4 = arith.addi %2, %c2 : index %4 = arith.addi %2, %c2 : index
br ^bb5(%4 : index) cf.br ^bb5(%4 : index)
// CHECK-NEXT: ^bb7: // pred: ^bb5 // CHECK-NEXT: ^bb7: // pred: ^bb5
// CHECK-NEXT: llvm.call @post({{.*}}) : (i64) -> () // CHECK-NEXT: llvm.call @post({{.*}}) : (i64) -> ()
@ -252,7 +252,7 @@ func @imperfectly_nested_loops() {
call @post(%0) : (index) -> () call @post(%0) : (index) -> ()
%c1 = arith.constant 1 : index %c1 = arith.constant 1 : index
%5 = arith.addi %0, %c1 : index %5 = arith.addi %0, %c1 : index
br ^bb2(%5 : index) cf.br ^bb2(%5 : index)
// CHECK-NEXT: ^bb8: // pred: ^bb2 // CHECK-NEXT: ^bb8: // pred: ^bb2
// CHECK-NEXT: llvm.return // CHECK-NEXT: llvm.return
@ -316,49 +316,49 @@ func private @body3(index, index)
// CHECK-NEXT: } // CHECK-NEXT: }
func @more_imperfectly_nested_loops() { func @more_imperfectly_nested_loops() {
^bb0: ^bb0:
br ^bb1 cf.br ^bb1
^bb1: // pred: ^bb0 ^bb1: // pred: ^bb0
%c0 = arith.constant 0 : index %c0 = arith.constant 0 : index
%c42 = arith.constant 42 : index %c42 = arith.constant 42 : index
br ^bb2(%c0 : index) cf.br ^bb2(%c0 : index)
^bb2(%0: index): // 2 preds: ^bb1, ^bb11 ^bb2(%0: index): // 2 preds: ^bb1, ^bb11
%1 = arith.cmpi slt, %0, %c42 : index %1 = arith.cmpi slt, %0, %c42 : index
cond_br %1, ^bb3, ^bb12 cf.cond_br %1, ^bb3, ^bb12
^bb3: // pred: ^bb2 ^bb3: // pred: ^bb2
call @pre(%0) : (index) -> () call @pre(%0) : (index) -> ()
br ^bb4 cf.br ^bb4
^bb4: // pred: ^bb3 ^bb4: // pred: ^bb3
%c7 = arith.constant 7 : index %c7 = arith.constant 7 : index
%c56 = arith.constant 56 : index %c56 = arith.constant 56 : index
br ^bb5(%c7 : index) cf.br ^bb5(%c7 : index)
^bb5(%2: index): // 2 preds: ^bb4, ^bb6 ^bb5(%2: index): // 2 preds: ^bb4, ^bb6
%3 = arith.cmpi slt, %2, %c56 : index %3 = arith.cmpi slt, %2, %c56 : index
cond_br %3, ^bb6, ^bb7 cf.cond_br %3, ^bb6, ^bb7
^bb6: // pred: ^bb5 ^bb6: // pred: ^bb5
call @body2(%0, %2) : (index, index) -> () call @body2(%0, %2) : (index, index) -> ()
%c2 = arith.constant 2 : index %c2 = arith.constant 2 : index
%4 = arith.addi %2, %c2 : index %4 = arith.addi %2, %c2 : index
br ^bb5(%4 : index) cf.br ^bb5(%4 : index)
^bb7: // pred: ^bb5 ^bb7: // pred: ^bb5
call @mid(%0) : (index) -> () call @mid(%0) : (index) -> ()
br ^bb8 cf.br ^bb8
^bb8: // pred: ^bb7 ^bb8: // pred: ^bb7
%c18 = arith.constant 18 : index %c18 = arith.constant 18 : index
%c37 = arith.constant 37 : index %c37 = arith.constant 37 : index
br ^bb9(%c18 : index) cf.br ^bb9(%c18 : index)
^bb9(%5: index): // 2 preds: ^bb8, ^bb10 ^bb9(%5: index): // 2 preds: ^bb8, ^bb10
%6 = arith.cmpi slt, %5, %c37 : index %6 = arith.cmpi slt, %5, %c37 : index
cond_br %6, ^bb10, ^bb11 cf.cond_br %6, ^bb10, ^bb11
^bb10: // pred: ^bb9 ^bb10: // pred: ^bb9
call @body3(%0, %5) : (index, index) -> () call @body3(%0, %5) : (index, index) -> ()
%c3 = arith.constant 3 : index %c3 = arith.constant 3 : index
%7 = arith.addi %5, %c3 : index %7 = arith.addi %5, %c3 : index
br ^bb9(%7 : index) cf.br ^bb9(%7 : index)
^bb11: // pred: ^bb9 ^bb11: // pred: ^bb9
call @post(%0) : (index) -> () call @post(%0) : (index) -> ()
%c1 = arith.constant 1 : index %c1 = arith.constant 1 : index
%8 = arith.addi %0, %c1 : index %8 = arith.addi %0, %c1 : index
br ^bb2(%8 : index) cf.br ^bb2(%8 : index)
^bb12: // pred: ^bb2 ^bb12: // pred: ^bb2
return return
} }
@ -432,7 +432,7 @@ func @dfs_block_order(%arg0: i32) -> (i32) {
// CHECK-NEXT: %[[CST:.*]] = llvm.mlir.constant(42 : i32) : i32 // CHECK-NEXT: %[[CST:.*]] = llvm.mlir.constant(42 : i32) : i32
%0 = arith.constant 42 : i32 %0 = arith.constant 42 : i32
// CHECK-NEXT: llvm.br ^bb2 // CHECK-NEXT: llvm.br ^bb2
br ^bb2 cf.br ^bb2
// CHECK-NEXT: ^bb1: // CHECK-NEXT: ^bb1:
// CHECK-NEXT: %[[ADD:.*]] = llvm.add %arg0, %[[CST]] : i32 // CHECK-NEXT: %[[ADD:.*]] = llvm.add %arg0, %[[CST]] : i32
@ -444,7 +444,7 @@ func @dfs_block_order(%arg0: i32) -> (i32) {
// CHECK-NEXT: ^bb2: // CHECK-NEXT: ^bb2:
^bb2: ^bb2:
// CHECK-NEXT: llvm.br ^bb1 // CHECK-NEXT: llvm.br ^bb1
br ^bb1 cf.br ^bb1
} }
// ----- // -----
@ -469,7 +469,7 @@ func @floorf(%arg0 : f32) {
// ----- // -----
// Lowers `assert` to a function call to `abort` if the assertion is violated. // Lowers `cf.assert` to a function call to `abort` if the assertion is violated.
// CHECK: llvm.func @abort() // CHECK: llvm.func @abort()
// CHECK-LABEL: @assert_test_function // CHECK-LABEL: @assert_test_function
// CHECK-SAME: (%[[ARG:.*]]: i1) // CHECK-SAME: (%[[ARG:.*]]: i1)
@ -480,7 +480,7 @@ func @assert_test_function(%arg : i1) {
// CHECK: ^[[FAILURE_BLOCK]]: // CHECK: ^[[FAILURE_BLOCK]]:
// CHECK: llvm.call @abort() : () -> () // CHECK: llvm.call @abort() : () -> ()
// CHECK: llvm.unreachable // CHECK: llvm.unreachable
assert %arg, "Computer says no" cf.assert %arg, "Computer says no"
return return
} }
@ -514,8 +514,8 @@ func @fmaf(%arg0: f32, %arg1: vector<4xf32>) {
// CHECK-LABEL: func @switchi8( // CHECK-LABEL: func @switchi8(
func @switchi8(%arg0 : i8) -> i32 { func @switchi8(%arg0 : i8) -> i32 {
switch %arg0 : i8, [ cf.switch %arg0 : i8, [
default: ^bb1, default: ^bb1,
42: ^bb1, 42: ^bb1,
43: ^bb3 43: ^bb3
] ]

View File

@ -900,45 +900,3 @@ func @tensor_extract_constant(%a : index, %b: index, %c: index) -> i32 {
// CHECK: spv.ReturnValue %[[VAL]] // CHECK: spv.ReturnValue %[[VAL]]
return %extract : i32 return %extract : i32
} }
// -----
//===----------------------------------------------------------------------===//
// std.br, std.cond_br
//===----------------------------------------------------------------------===//
module attributes {
spv.target_env = #spv.target_env<#spv.vce<v1.0, [], []>, {}>
} {
// CHECK-LABEL: func @simple_loop
func @simple_loop(index, index, index) {
^bb0(%begin : index, %end : index, %step : index):
// CHECK-NEXT: spv.Branch ^bb1
br ^bb1
// CHECK-NEXT: ^bb1: // pred: ^bb0
// CHECK-NEXT: spv.Branch ^bb2({{.*}} : i32)
^bb1: // pred: ^bb0
br ^bb2(%begin : index)
// CHECK: ^bb2({{.*}}: i32): // 2 preds: ^bb1, ^bb3
// CHECK-NEXT: {{.*}} = spv.SLessThan {{.*}}, {{.*}} : i32
// CHECK-NEXT: spv.BranchConditional {{.*}}, ^bb3, ^bb4
^bb2(%0: index): // 2 preds: ^bb1, ^bb3
%1 = arith.cmpi slt, %0, %end : index
cond_br %1, ^bb3, ^bb4
// CHECK: ^bb3: // pred: ^bb2
// CHECK-NEXT: {{.*}} = spv.IAdd {{.*}}, {{.*}} : i32
// CHECK-NEXT: spv.Branch ^bb2({{.*}} : i32)
^bb3: // pred: ^bb2
%2 = arith.addi %0, %step : index
br ^bb2(%2 : index)
// CHECK: ^bb4: // pred: ^bb2
^bb4: // pred: ^bb2
return
}
}

View File

@ -56,9 +56,9 @@ func @affine_load_invalid_dim(%M : memref<10xi32>) {
^bb0(%arg: index): ^bb0(%arg: index):
affine.load %M[%arg] : memref<10xi32> affine.load %M[%arg] : memref<10xi32>
// expected-error@-1 {{index must be a dimension or symbol identifier}} // expected-error@-1 {{index must be a dimension or symbol identifier}}
br ^bb1 cf.br ^bb1
^bb1: ^bb1:
br ^bb1 cf.br ^bb1
}) : () -> () }) : () -> ()
return return
} }

View File

@ -54,13 +54,13 @@ func @token_value_to_func() {
// CHECK-LABEL: @token_arg_cond_br_await_with_fallthough // CHECK-LABEL: @token_arg_cond_br_await_with_fallthough
// CHECK: %[[TOKEN:.*]]: !async.token // CHECK: %[[TOKEN:.*]]: !async.token
func @token_arg_cond_br_await_with_fallthough(%arg0: !async.token, %arg1: i1) { func @token_arg_cond_br_await_with_fallthough(%arg0: !async.token, %arg1: i1) {
// CHECK: cond_br // CHECK: cf.cond_br
// CHECK-SAME: ^[[BB1:.*]], ^[[BB2:.*]] // CHECK-SAME: ^[[BB1:.*]], ^[[BB2:.*]]
cond_br %arg1, ^bb1, ^bb2 cf.cond_br %arg1, ^bb1, ^bb2
^bb1: ^bb1:
// CHECK: ^[[BB1]]: // CHECK: ^[[BB1]]:
// CHECK: br ^[[BB2]] // CHECK: cf.br ^[[BB2]]
br ^bb2 cf.br ^bb2
^bb2: ^bb2:
// CHECK: ^[[BB2]]: // CHECK: ^[[BB2]]:
// CHECK: async.runtime.await %[[TOKEN]] // CHECK: async.runtime.await %[[TOKEN]]
@ -88,10 +88,10 @@ func @token_coro_return() -> !async.token {
async.runtime.resume %hdl async.runtime.resume %hdl
async.coro.suspend %saved, ^suspend, ^resume, ^cleanup async.coro.suspend %saved, ^suspend, ^resume, ^cleanup
^resume: ^resume:
br ^cleanup cf.br ^cleanup
^cleanup: ^cleanup:
async.coro.free %id, %hdl async.coro.free %id, %hdl
br ^suspend cf.br ^suspend
^suspend: ^suspend:
async.coro.end %hdl async.coro.end %hdl
return %token : !async.token return %token : !async.token
@ -109,10 +109,10 @@ func @token_coro_await_and_resume(%arg0: !async.token) -> !async.token {
// CHECK-NEXT: async.runtime.drop_ref %[[TOKEN]] {count = 1 : i64} // CHECK-NEXT: async.runtime.drop_ref %[[TOKEN]] {count = 1 : i64}
async.coro.suspend %saved, ^suspend, ^resume, ^cleanup async.coro.suspend %saved, ^suspend, ^resume, ^cleanup
^resume: ^resume:
br ^cleanup cf.br ^cleanup
^cleanup: ^cleanup:
async.coro.free %id, %hdl async.coro.free %id, %hdl
br ^suspend cf.br ^suspend
^suspend: ^suspend:
async.coro.end %hdl async.coro.end %hdl
return %token : !async.token return %token : !async.token
@ -137,10 +137,10 @@ func @value_coro_await_and_resume(%arg0: !async.value<f32>) -> !async.token {
%0 = async.runtime.load %arg0 : !async.value<f32> %0 = async.runtime.load %arg0 : !async.value<f32>
// CHECK: arith.addf %[[LOADED]], %[[LOADED]] // CHECK: arith.addf %[[LOADED]], %[[LOADED]]
%1 = arith.addf %0, %0 : f32 %1 = arith.addf %0, %0 : f32
br ^cleanup cf.br ^cleanup
^cleanup: ^cleanup:
async.coro.free %id, %hdl async.coro.free %id, %hdl
br ^suspend cf.br ^suspend
^suspend: ^suspend:
async.coro.end %hdl async.coro.end %hdl
return %token : !async.token return %token : !async.token
@ -167,12 +167,12 @@ func private @outlined_async_execute(%arg0: !async.token) -> !async.token {
// CHECK: ^[[RESUME_1:.*]]: // CHECK: ^[[RESUME_1:.*]]:
// CHECK: async.runtime.set_available // CHECK: async.runtime.set_available
async.runtime.set_available %0 : !async.token async.runtime.set_available %0 : !async.token
br ^cleanup cf.br ^cleanup
^cleanup: ^cleanup:
// CHECK: ^[[CLEANUP:.*]]: // CHECK: ^[[CLEANUP:.*]]:
// CHECK: async.coro.free // CHECK: async.coro.free
async.coro.free %1, %2 async.coro.free %1, %2
br ^suspend cf.br ^suspend
^suspend: ^suspend:
// CHECK: ^[[SUSPEND:.*]]: // CHECK: ^[[SUSPEND:.*]]:
// CHECK: async.coro.end // CHECK: async.coro.end
@ -198,7 +198,7 @@ func @token_await_inside_nested_region(%arg0: i1) {
// CHECK-LABEL: @token_defined_in_the_loop // CHECK-LABEL: @token_defined_in_the_loop
func @token_defined_in_the_loop() { func @token_defined_in_the_loop() {
br ^bb1 cf.br ^bb1
^bb1: ^bb1:
// CHECK: ^[[BB1:.*]]: // CHECK: ^[[BB1:.*]]:
// CHECK: %[[TOKEN:.*]] = call @token() // CHECK: %[[TOKEN:.*]] = call @token()
@ -207,7 +207,7 @@ func @token_defined_in_the_loop() {
// CHECK: async.runtime.drop_ref %[[TOKEN]] {count = 1 : i64} // CHECK: async.runtime.drop_ref %[[TOKEN]] {count = 1 : i64}
async.runtime.await %token : !async.token async.runtime.await %token : !async.token
%0 = call @cond(): () -> (i1) %0 = call @cond(): () -> (i1)
cond_br %0, ^bb1, ^bb2 cf.cond_br %0, ^bb1, ^bb2
^bb2: ^bb2:
// CHECK: ^[[BB2:.*]]: // CHECK: ^[[BB2:.*]]:
// CHECK: return // CHECK: return
@ -218,18 +218,18 @@ func @token_defined_in_the_loop() {
func @divergent_liveness_one_token(%arg0 : i1) { func @divergent_liveness_one_token(%arg0 : i1) {
// CHECK: %[[TOKEN:.*]] = call @token() // CHECK: %[[TOKEN:.*]] = call @token()
%token = call @token() : () -> !async.token %token = call @token() : () -> !async.token
// CHECK: cond_br %arg0, ^[[LIVE_IN:.*]], ^[[REF_COUNTING:.*]] // CHECK: cf.cond_br %arg0, ^[[LIVE_IN:.*]], ^[[REF_COUNTING:.*]]
cond_br %arg0, ^bb1, ^bb2 cf.cond_br %arg0, ^bb1, ^bb2
^bb1: ^bb1:
// CHECK: ^[[LIVE_IN]]: // CHECK: ^[[LIVE_IN]]:
// CHECK: async.runtime.await %[[TOKEN]] // CHECK: async.runtime.await %[[TOKEN]]
// CHECK: async.runtime.drop_ref %[[TOKEN]] {count = 1 : i64} // CHECK: async.runtime.drop_ref %[[TOKEN]] {count = 1 : i64}
// CHECK: br ^[[RETURN:.*]] // CHECK: cf.br ^[[RETURN:.*]]
async.runtime.await %token : !async.token async.runtime.await %token : !async.token
br ^bb2 cf.br ^bb2
// CHECK: ^[[REF_COUNTING:.*]]: // CHECK: ^[[REF_COUNTING:.*]]:
// CHECK: async.runtime.drop_ref %[[TOKEN]] {count = 1 : i64} // CHECK: async.runtime.drop_ref %[[TOKEN]] {count = 1 : i64}
// CHECK: br ^[[RETURN:.*]] // CHECK: cf.br ^[[RETURN:.*]]
^bb2: ^bb2:
// CHECK: ^[[RETURN]]: // CHECK: ^[[RETURN]]:
// CHECK: return // CHECK: return
@ -240,20 +240,20 @@ func @divergent_liveness_one_token(%arg0 : i1) {
func @divergent_liveness_unique_predecessor(%arg0 : i1) { func @divergent_liveness_unique_predecessor(%arg0 : i1) {
// CHECK: %[[TOKEN:.*]] = call @token() // CHECK: %[[TOKEN:.*]] = call @token()
%token = call @token() : () -> !async.token %token = call @token() : () -> !async.token
// CHECK: cond_br %arg0, ^[[LIVE_IN:.*]], ^[[NO_LIVE_IN:.*]] // CHECK: cf.cond_br %arg0, ^[[LIVE_IN:.*]], ^[[NO_LIVE_IN:.*]]
cond_br %arg0, ^bb2, ^bb1 cf.cond_br %arg0, ^bb2, ^bb1
^bb1: ^bb1:
// CHECK: ^[[NO_LIVE_IN]]: // CHECK: ^[[NO_LIVE_IN]]:
// CHECK: async.runtime.drop_ref %[[TOKEN]] {count = 1 : i64} // CHECK: async.runtime.drop_ref %[[TOKEN]] {count = 1 : i64}
// CHECK: br ^[[RETURN:.*]] // CHECK: cf.br ^[[RETURN:.*]]
br ^bb3 cf.br ^bb3
^bb2: ^bb2:
// CHECK: ^[[LIVE_IN]]: // CHECK: ^[[LIVE_IN]]:
// CHECK: async.runtime.await %[[TOKEN]] // CHECK: async.runtime.await %[[TOKEN]]
// CHECK: async.runtime.drop_ref %[[TOKEN]] {count = 1 : i64} // CHECK: async.runtime.drop_ref %[[TOKEN]] {count = 1 : i64}
// CHECK: br ^[[RETURN]] // CHECK: cf.br ^[[RETURN]]
async.runtime.await %token : !async.token async.runtime.await %token : !async.token
br ^bb3 cf.br ^bb3
^bb3: ^bb3:
// CHECK: ^[[RETURN]]: // CHECK: ^[[RETURN]]:
// CHECK: return // CHECK: return
@ -266,24 +266,24 @@ func @divergent_liveness_two_tokens(%arg0 : i1) {
// CHECK: %[[TOKEN1:.*]] = call @token() // CHECK: %[[TOKEN1:.*]] = call @token()
%token0 = call @token() : () -> !async.token %token0 = call @token() : () -> !async.token
%token1 = call @token() : () -> !async.token %token1 = call @token() : () -> !async.token
// CHECK: cond_br %arg0, ^[[AWAIT0:.*]], ^[[AWAIT1:.*]] // CHECK: cf.cond_br %arg0, ^[[AWAIT0:.*]], ^[[AWAIT1:.*]]
cond_br %arg0, ^await0, ^await1 cf.cond_br %arg0, ^await0, ^await1
^await0: ^await0:
// CHECK: ^[[AWAIT0]]: // CHECK: ^[[AWAIT0]]:
// CHECK: async.runtime.drop_ref %[[TOKEN1]] {count = 1 : i64} // CHECK: async.runtime.drop_ref %[[TOKEN1]] {count = 1 : i64}
// CHECK: async.runtime.await %[[TOKEN0]] // CHECK: async.runtime.await %[[TOKEN0]]
// CHECK: async.runtime.drop_ref %[[TOKEN0]] {count = 1 : i64} // CHECK: async.runtime.drop_ref %[[TOKEN0]] {count = 1 : i64}
// CHECK: br ^[[RETURN:.*]] // CHECK: cf.br ^[[RETURN:.*]]
async.runtime.await %token0 : !async.token async.runtime.await %token0 : !async.token
br ^ret cf.br ^ret
^await1: ^await1:
// CHECK: ^[[AWAIT1]]: // CHECK: ^[[AWAIT1]]:
// CHECK: async.runtime.drop_ref %[[TOKEN0]] {count = 1 : i64} // CHECK: async.runtime.drop_ref %[[TOKEN0]] {count = 1 : i64}
// CHECK: async.runtime.await %[[TOKEN1]] // CHECK: async.runtime.await %[[TOKEN1]]
// CHECK: async.runtime.drop_ref %[[TOKEN1]] {count = 1 : i64} // CHECK: async.runtime.drop_ref %[[TOKEN1]] {count = 1 : i64}
// CHECK: br ^[[RETURN]] // CHECK: cf.br ^[[RETURN]]
async.runtime.await %token1 : !async.token async.runtime.await %token1 : !async.token
br ^ret cf.br ^ret
^ret: ^ret:
// CHECK: ^[[RETURN]]: // CHECK: ^[[RETURN]]:
// CHECK: return // CHECK: return

View File

@ -10,7 +10,7 @@ func @simple_callee(%arg0: f32) -> (f32 {builtin.foo = "bar"}) {
// CHECK: %[[RETURNED_STORAGE:.*]] = async.runtime.create : !async.value<f32> // CHECK: %[[RETURNED_STORAGE:.*]] = async.runtime.create : !async.value<f32>
// CHECK: %[[ID:.*]] = async.coro.id // CHECK: %[[ID:.*]] = async.coro.id
// CHECK: %[[HDL:.*]] = async.coro.begin %[[ID]] // CHECK: %[[HDL:.*]] = async.coro.begin %[[ID]]
// CHECK: br ^[[ORIGINAL_ENTRY:.*]] // CHECK: cf.br ^[[ORIGINAL_ENTRY:.*]]
// CHECK ^[[ORIGINAL_ENTRY]]: // CHECK ^[[ORIGINAL_ENTRY]]:
// CHECK: %[[VAL:.*]] = arith.addf %[[ARG]], %[[ARG]] : f32 // CHECK: %[[VAL:.*]] = arith.addf %[[ARG]], %[[ARG]] : f32
%0 = arith.addf %arg0, %arg0 : f32 %0 = arith.addf %arg0, %arg0 : f32
@ -29,7 +29,7 @@ func @simple_callee(%arg0: f32) -> (f32 {builtin.foo = "bar"}) {
// CHECK: ^[[RESUME]]: // CHECK: ^[[RESUME]]:
// CHECK: %[[IS_ERROR:.*]] = async.runtime.is_error %[[VAL_STORAGE]] : !async.value<f32> // CHECK: %[[IS_ERROR:.*]] = async.runtime.is_error %[[VAL_STORAGE]] : !async.value<f32>
// CHECK: cond_br %[[IS_ERROR]], ^[[BRANCH_ERROR:.*]], ^[[BRANCH_OK:.*]] // CHECK: cf.cond_br %[[IS_ERROR]], ^[[BRANCH_ERROR:.*]], ^[[BRANCH_OK:.*]]
// CHECK: ^[[BRANCH_OK]]: // CHECK: ^[[BRANCH_OK]]:
// CHECK: %[[LOADED:.*]] = async.runtime.load %[[VAL_STORAGE]] : <f32> // CHECK: %[[LOADED:.*]] = async.runtime.load %[[VAL_STORAGE]] : <f32>
@ -37,19 +37,19 @@ func @simple_callee(%arg0: f32) -> (f32 {builtin.foo = "bar"}) {
// CHECK: async.runtime.store %[[RETURNED]], %[[RETURNED_STORAGE]] : <f32> // CHECK: async.runtime.store %[[RETURNED]], %[[RETURNED_STORAGE]] : <f32>
// CHECK: async.runtime.set_available %[[RETURNED_STORAGE]] // CHECK: async.runtime.set_available %[[RETURNED_STORAGE]]
// CHECK: async.runtime.set_available %[[TOKEN]] // CHECK: async.runtime.set_available %[[TOKEN]]
// CHECK: br ^[[CLEANUP]] // CHECK: cf.br ^[[CLEANUP]]
%3 = arith.mulf %arg0, %2 : f32 %3 = arith.mulf %arg0, %2 : f32
return %3: f32 return %3: f32
// CHECK: ^[[BRANCH_ERROR]]: // CHECK: ^[[BRANCH_ERROR]]:
// CHECK: async.runtime.set_error %[[TOKEN]] // CHECK: async.runtime.set_error %[[TOKEN]]
// CHECK: async.runtime.set_error %[[RETURNED_STORAGE]] // CHECK: async.runtime.set_error %[[RETURNED_STORAGE]]
// CHECK: br ^[[CLEANUP]] // CHECK: cf.br ^[[CLEANUP]]
// CHECK: ^[[CLEANUP]]: // CHECK: ^[[CLEANUP]]:
// CHECK: async.coro.free %[[ID]], %[[HDL]] // CHECK: async.coro.free %[[ID]], %[[HDL]]
// CHECK: br ^[[SUSPEND]] // CHECK: cf.br ^[[SUSPEND]]
// CHECK: ^[[SUSPEND]]: // CHECK: ^[[SUSPEND]]:
// CHECK: async.coro.end %[[HDL]] // CHECK: async.coro.end %[[HDL]]
@ -63,7 +63,7 @@ func @simple_caller() -> f32 {
// CHECK: %[[RETURNED_STORAGE:.*]] = async.runtime.create : !async.value<f32> // CHECK: %[[RETURNED_STORAGE:.*]] = async.runtime.create : !async.value<f32>
// CHECK: %[[ID:.*]] = async.coro.id // CHECK: %[[ID:.*]] = async.coro.id
// CHECK: %[[HDL:.*]] = async.coro.begin %[[ID]] // CHECK: %[[HDL:.*]] = async.coro.begin %[[ID]]
// CHECK: br ^[[ORIGINAL_ENTRY:.*]] // CHECK: cf.br ^[[ORIGINAL_ENTRY:.*]]
// CHECK ^[[ORIGINAL_ENTRY]]: // CHECK ^[[ORIGINAL_ENTRY]]:
// CHECK: %[[CONSTANT:.*]] = arith.constant // CHECK: %[[CONSTANT:.*]] = arith.constant
@ -77,28 +77,28 @@ func @simple_caller() -> f32 {
// CHECK: ^[[RESUME]]: // CHECK: ^[[RESUME]]:
// CHECK: %[[IS_TOKEN_ERROR:.*]] = async.runtime.is_error %[[RETURNED_TO_CALLER]]#0 : !async.token // CHECK: %[[IS_TOKEN_ERROR:.*]] = async.runtime.is_error %[[RETURNED_TO_CALLER]]#0 : !async.token
// CHECK: cond_br %[[IS_TOKEN_ERROR]], ^[[BRANCH_ERROR:.*]], ^[[BRANCH_TOKEN_OK:.*]] // CHECK: cf.cond_br %[[IS_TOKEN_ERROR]], ^[[BRANCH_ERROR:.*]], ^[[BRANCH_TOKEN_OK:.*]]
// CHECK: ^[[BRANCH_TOKEN_OK]]: // CHECK: ^[[BRANCH_TOKEN_OK]]:
// CHECK: %[[IS_VALUE_ERROR:.*]] = async.runtime.is_error %[[RETURNED_TO_CALLER]]#1 : !async.value<f32> // CHECK: %[[IS_VALUE_ERROR:.*]] = async.runtime.is_error %[[RETURNED_TO_CALLER]]#1 : !async.value<f32>
// CHECK: cond_br %[[IS_VALUE_ERROR]], ^[[BRANCH_ERROR:.*]], ^[[BRANCH_VALUE_OK:.*]] // CHECK: cf.cond_br %[[IS_VALUE_ERROR]], ^[[BRANCH_ERROR:.*]], ^[[BRANCH_VALUE_OK:.*]]
// CHECK: ^[[BRANCH_VALUE_OK]]: // CHECK: ^[[BRANCH_VALUE_OK]]:
// CHECK: %[[LOADED:.*]] = async.runtime.load %[[RETURNED_TO_CALLER]]#1 : <f32> // CHECK: %[[LOADED:.*]] = async.runtime.load %[[RETURNED_TO_CALLER]]#1 : <f32>
// CHECK: async.runtime.store %[[LOADED]], %[[RETURNED_STORAGE]] : <f32> // CHECK: async.runtime.store %[[LOADED]], %[[RETURNED_STORAGE]] : <f32>
// CHECK: async.runtime.set_available %[[RETURNED_STORAGE]] // CHECK: async.runtime.set_available %[[RETURNED_STORAGE]]
// CHECK: async.runtime.set_available %[[TOKEN]] // CHECK: async.runtime.set_available %[[TOKEN]]
// CHECK: br ^[[CLEANUP]] // CHECK: cf.br ^[[CLEANUP]]
return %r: f32 return %r: f32
// CHECK: ^[[BRANCH_ERROR]]: // CHECK: ^[[BRANCH_ERROR]]:
// CHECK: async.runtime.set_error %[[TOKEN]] // CHECK: async.runtime.set_error %[[TOKEN]]
// CHECK: async.runtime.set_error %[[RETURNED_STORAGE]] // CHECK: async.runtime.set_error %[[RETURNED_STORAGE]]
// CHECK: br ^[[CLEANUP]] // CHECK: cf.br ^[[CLEANUP]]
// CHECK: ^[[CLEANUP]]: // CHECK: ^[[CLEANUP]]:
// CHECK: async.coro.free %[[ID]], %[[HDL]] // CHECK: async.coro.free %[[ID]], %[[HDL]]
// CHECK: br ^[[SUSPEND]] // CHECK: cf.br ^[[SUSPEND]]
// CHECK: ^[[SUSPEND]]: // CHECK: ^[[SUSPEND]]:
// CHECK: async.coro.end %[[HDL]] // CHECK: async.coro.end %[[HDL]]
@ -112,7 +112,7 @@ func @double_caller() -> f32 {
// CHECK: %[[RETURNED_STORAGE:.*]] = async.runtime.create : !async.value<f32> // CHECK: %[[RETURNED_STORAGE:.*]] = async.runtime.create : !async.value<f32>
// CHECK: %[[ID:.*]] = async.coro.id // CHECK: %[[ID:.*]] = async.coro.id
// CHECK: %[[HDL:.*]] = async.coro.begin %[[ID]] // CHECK: %[[HDL:.*]] = async.coro.begin %[[ID]]
// CHECK: br ^[[ORIGINAL_ENTRY:.*]] // CHECK: cf.br ^[[ORIGINAL_ENTRY:.*]]
// CHECK ^[[ORIGINAL_ENTRY]]: // CHECK ^[[ORIGINAL_ENTRY]]:
// CHECK: %[[CONSTANT:.*]] = arith.constant // CHECK: %[[CONSTANT:.*]] = arith.constant
@ -126,11 +126,11 @@ func @double_caller() -> f32 {
// CHECK: ^[[RESUME_1]]: // CHECK: ^[[RESUME_1]]:
// CHECK: %[[IS_TOKEN_ERROR_1:.*]] = async.runtime.is_error %[[RETURNED_TO_CALLER_1]]#0 : !async.token // CHECK: %[[IS_TOKEN_ERROR_1:.*]] = async.runtime.is_error %[[RETURNED_TO_CALLER_1]]#0 : !async.token
// CHECK: cond_br %[[IS_TOKEN_ERROR_1]], ^[[BRANCH_ERROR:.*]], ^[[BRANCH_TOKEN_OK_1:.*]] // CHECK: cf.cond_br %[[IS_TOKEN_ERROR_1]], ^[[BRANCH_ERROR:.*]], ^[[BRANCH_TOKEN_OK_1:.*]]
// CHECK: ^[[BRANCH_TOKEN_OK_1]]: // CHECK: ^[[BRANCH_TOKEN_OK_1]]:
// CHECK: %[[IS_VALUE_ERROR_1:.*]] = async.runtime.is_error %[[RETURNED_TO_CALLER_1]]#1 : !async.value<f32> // CHECK: %[[IS_VALUE_ERROR_1:.*]] = async.runtime.is_error %[[RETURNED_TO_CALLER_1]]#1 : !async.value<f32>
// CHECK: cond_br %[[IS_VALUE_ERROR_1]], ^[[BRANCH_ERROR:.*]], ^[[BRANCH_VALUE_OK_1:.*]] // CHECK: cf.cond_br %[[IS_VALUE_ERROR_1]], ^[[BRANCH_ERROR:.*]], ^[[BRANCH_VALUE_OK_1:.*]]
// CHECK: ^[[BRANCH_VALUE_OK_1]]: // CHECK: ^[[BRANCH_VALUE_OK_1]]:
// CHECK: %[[LOADED_1:.*]] = async.runtime.load %[[RETURNED_TO_CALLER_1]]#1 : <f32> // CHECK: %[[LOADED_1:.*]] = async.runtime.load %[[RETURNED_TO_CALLER_1]]#1 : <f32>
@ -143,27 +143,27 @@ func @double_caller() -> f32 {
// CHECK: ^[[RESUME_2]]: // CHECK: ^[[RESUME_2]]:
// CHECK: %[[IS_TOKEN_ERROR_2:.*]] = async.runtime.is_error %[[RETURNED_TO_CALLER_2]]#0 : !async.token // CHECK: %[[IS_TOKEN_ERROR_2:.*]] = async.runtime.is_error %[[RETURNED_TO_CALLER_2]]#0 : !async.token
// CHECK: cond_br %[[IS_TOKEN_ERROR_2]], ^[[BRANCH_ERROR:.*]], ^[[BRANCH_TOKEN_OK_2:.*]] // CHECK: cf.cond_br %[[IS_TOKEN_ERROR_2]], ^[[BRANCH_ERROR:.*]], ^[[BRANCH_TOKEN_OK_2:.*]]
// CHECK: ^[[BRANCH_TOKEN_OK_2]]: // CHECK: ^[[BRANCH_TOKEN_OK_2]]:
// CHECK: %[[IS_VALUE_ERROR_2:.*]] = async.runtime.is_error %[[RETURNED_TO_CALLER_2]]#1 : !async.value<f32> // CHECK: %[[IS_VALUE_ERROR_2:.*]] = async.runtime.is_error %[[RETURNED_TO_CALLER_2]]#1 : !async.value<f32>
// CHECK: cond_br %[[IS_VALUE_ERROR_2]], ^[[BRANCH_ERROR:.*]], ^[[BRANCH_VALUE_OK_2:.*]] // CHECK: cf.cond_br %[[IS_VALUE_ERROR_2]], ^[[BRANCH_ERROR:.*]], ^[[BRANCH_VALUE_OK_2:.*]]
// CHECK: ^[[BRANCH_VALUE_OK_2]]: // CHECK: ^[[BRANCH_VALUE_OK_2]]:
// CHECK: %[[LOADED_2:.*]] = async.runtime.load %[[RETURNED_TO_CALLER_2]]#1 : <f32> // CHECK: %[[LOADED_2:.*]] = async.runtime.load %[[RETURNED_TO_CALLER_2]]#1 : <f32>
// CHECK: async.runtime.store %[[LOADED_2]], %[[RETURNED_STORAGE]] : <f32> // CHECK: async.runtime.store %[[LOADED_2]], %[[RETURNED_STORAGE]] : <f32>
// CHECK: async.runtime.set_available %[[RETURNED_STORAGE]] // CHECK: async.runtime.set_available %[[RETURNED_STORAGE]]
// CHECK: async.runtime.set_available %[[TOKEN]] // CHECK: async.runtime.set_available %[[TOKEN]]
// CHECK: br ^[[CLEANUP]] // CHECK: cf.br ^[[CLEANUP]]
return %s: f32 return %s: f32
// CHECK: ^[[BRANCH_ERROR]]: // CHECK: ^[[BRANCH_ERROR]]:
// CHECK: async.runtime.set_error %[[TOKEN]] // CHECK: async.runtime.set_error %[[TOKEN]]
// CHECK: async.runtime.set_error %[[RETURNED_STORAGE]] // CHECK: async.runtime.set_error %[[RETURNED_STORAGE]]
// CHECK: br ^[[CLEANUP]] // CHECK: cf.br ^[[CLEANUP]]
// CHECK: ^[[CLEANUP]]: // CHECK: ^[[CLEANUP]]:
// CHECK: async.coro.free %[[ID]], %[[HDL]] // CHECK: async.coro.free %[[ID]], %[[HDL]]
// CHECK: br ^[[SUSPEND]] // CHECK: cf.br ^[[SUSPEND]]
// CHECK: ^[[SUSPEND]]: // CHECK: ^[[SUSPEND]]:
// CHECK: async.coro.end %[[HDL]] // CHECK: async.coro.end %[[HDL]]
@ -184,7 +184,7 @@ func @recursive(%arg: !async.token) {
async.await %arg : !async.token async.await %arg : !async.token
// CHECK: ^[[RESUME_1]]: // CHECK: ^[[RESUME_1]]:
// CHECK: %[[IS_ERROR:.*]] = async.runtime.is_error %[[ARG]] : !async.token // CHECK: %[[IS_ERROR:.*]] = async.runtime.is_error %[[ARG]] : !async.token
// CHECK: cond_br %[[IS_ERROR]], ^[[BRANCH_ERROR:.*]], ^[[BRANCH_OK:.*]] // CHECK: cf.cond_br %[[IS_ERROR]], ^[[BRANCH_ERROR:.*]], ^[[BRANCH_OK:.*]]
// CHECK: ^[[BRANCH_OK]]: // CHECK: ^[[BRANCH_OK]]:
// CHECK: %[[GIVEN:.*]] = async.runtime.create : !async.token // CHECK: %[[GIVEN:.*]] = async.runtime.create : !async.token
@ -200,16 +200,16 @@ call @recursive(%r): (!async.token) -> ()
// CHECK: ^[[RESUME_2]]: // CHECK: ^[[RESUME_2]]:
// CHECK: async.runtime.set_available %[[TOKEN]] // CHECK: async.runtime.set_available %[[TOKEN]]
// CHECK: br ^[[CLEANUP]] // CHECK: cf.br ^[[CLEANUP]]
// CHECK: ^[[BRANCH_ERROR]]: // CHECK: ^[[BRANCH_ERROR]]:
// CHECK: async.runtime.set_error %[[TOKEN]] // CHECK: async.runtime.set_error %[[TOKEN]]
// CHECK: br ^[[CLEANUP]] // CHECK: cf.br ^[[CLEANUP]]
return return
// CHECK: ^[[CLEANUP]]: // CHECK: ^[[CLEANUP]]:
// CHECK: async.coro.free %[[ID]], %[[HDL]] // CHECK: async.coro.free %[[ID]], %[[HDL]]
// CHECK: br ^[[SUSPEND]] // CHECK: cf.br ^[[SUSPEND]]
// CHECK: ^[[SUSPEND]]: // CHECK: ^[[SUSPEND]]:
// CHECK: async.coro.end %[[HDL]] // CHECK: async.coro.end %[[HDL]]
@ -230,7 +230,7 @@ func @corecursive1(%arg: !async.token) {
async.await %arg : !async.token async.await %arg : !async.token
// CHECK: ^[[RESUME_1]]: // CHECK: ^[[RESUME_1]]:
// CHECK: %[[IS_ERROR:.*]] = async.runtime.is_error %[[ARG]] : !async.token // CHECK: %[[IS_ERROR:.*]] = async.runtime.is_error %[[ARG]] : !async.token
// CHECK: cond_br %[[IS_ERROR]], ^[[BRANCH_ERROR:.*]], ^[[BRANCH_OK:.*]] // CHECK: cf.cond_br %[[IS_ERROR]], ^[[BRANCH_ERROR:.*]], ^[[BRANCH_OK:.*]]
// CHECK: ^[[BRANCH_OK]]: // CHECK: ^[[BRANCH_OK]]:
// CHECK: %[[GIVEN:.*]] = async.runtime.create : !async.token // CHECK: %[[GIVEN:.*]] = async.runtime.create : !async.token
@ -246,16 +246,16 @@ call @corecursive2(%r): (!async.token) -> ()
// CHECK: ^[[RESUME_2]]: // CHECK: ^[[RESUME_2]]:
// CHECK: async.runtime.set_available %[[TOKEN]] // CHECK: async.runtime.set_available %[[TOKEN]]
// CHECK: br ^[[CLEANUP]] // CHECK: cf.br ^[[CLEANUP]]
// CHECK: ^[[BRANCH_ERROR]]: // CHECK: ^[[BRANCH_ERROR]]:
// CHECK: async.runtime.set_error %[[TOKEN]] // CHECK: async.runtime.set_error %[[TOKEN]]
// CHECK: br ^[[CLEANUP]] // CHECK: cf.br ^[[CLEANUP]]
return return
// CHECK: ^[[CLEANUP]]: // CHECK: ^[[CLEANUP]]:
// CHECK: async.coro.free %[[ID]], %[[HDL]] // CHECK: async.coro.free %[[ID]], %[[HDL]]
// CHECK: br ^[[SUSPEND]] // CHECK: cf.br ^[[SUSPEND]]
// CHECK: ^[[SUSPEND]]: // CHECK: ^[[SUSPEND]]:
// CHECK: async.coro.end %[[HDL]] // CHECK: async.coro.end %[[HDL]]
@ -276,7 +276,7 @@ func @corecursive2(%arg: !async.token) {
async.await %arg : !async.token async.await %arg : !async.token
// CHECK: ^[[RESUME_1]]: // CHECK: ^[[RESUME_1]]:
// CHECK: %[[IS_ERROR:.*]] = async.runtime.is_error %[[ARG]] : !async.token // CHECK: %[[IS_ERROR:.*]] = async.runtime.is_error %[[ARG]] : !async.token
// CHECK: cond_br %[[IS_ERROR]], ^[[BRANCH_ERROR:.*]], ^[[BRANCH_OK:.*]] // CHECK: cf.cond_br %[[IS_ERROR]], ^[[BRANCH_ERROR:.*]], ^[[BRANCH_OK:.*]]
// CHECK: ^[[BRANCH_OK]]: // CHECK: ^[[BRANCH_OK]]:
// CHECK: %[[GIVEN:.*]] = async.runtime.create : !async.token // CHECK: %[[GIVEN:.*]] = async.runtime.create : !async.token
@ -292,16 +292,16 @@ call @corecursive1(%r): (!async.token) -> ()
// CHECK: ^[[RESUME_2]]: // CHECK: ^[[RESUME_2]]:
// CHECK: async.runtime.set_available %[[TOKEN]] // CHECK: async.runtime.set_available %[[TOKEN]]
// CHECK: br ^[[CLEANUP]] // CHECK: cf.br ^[[CLEANUP]]
// CHECK: ^[[BRANCH_ERROR]]: // CHECK: ^[[BRANCH_ERROR]]:
// CHECK: async.runtime.set_error %[[TOKEN]] // CHECK: async.runtime.set_error %[[TOKEN]]
// CHECK: br ^[[CLEANUP]] // CHECK: cf.br ^[[CLEANUP]]
return return
// CHECK: ^[[CLEANUP]]: // CHECK: ^[[CLEANUP]]:
// CHECK: async.coro.free %[[ID]], %[[HDL]] // CHECK: async.coro.free %[[ID]], %[[HDL]]
// CHECK: br ^[[SUSPEND]] // CHECK: cf.br ^[[SUSPEND]]
// CHECK: ^[[SUSPEND]]: // CHECK: ^[[SUSPEND]]:
// CHECK: async.coro.end %[[HDL]] // CHECK: async.coro.end %[[HDL]]

View File

@ -63,7 +63,7 @@ func @nested_async_execute(%arg0: f32, %arg1: f32, %arg2: memref<1xf32>) {
// CHECK: %[[IS_ERROR:.*]] = async.runtime.is_error %[[TOKEN]] // CHECK: %[[IS_ERROR:.*]] = async.runtime.is_error %[[TOKEN]]
// CHECK: %[[TRUE:.*]] = arith.constant true // CHECK: %[[TRUE:.*]] = arith.constant true
// CHECK: %[[NOT_ERROR:.*]] = arith.xori %[[IS_ERROR]], %[[TRUE]] : i1 // CHECK: %[[NOT_ERROR:.*]] = arith.xori %[[IS_ERROR]], %[[TRUE]] : i1
// CHECK: assert %[[NOT_ERROR]] // CHECK: cf.assert %[[NOT_ERROR]]
// CHECK-NEXT: return // CHECK-NEXT: return
async.await %token0 : !async.token async.await %token0 : !async.token
return return
@ -109,7 +109,7 @@ func @nested_async_execute(%arg0: f32, %arg1: f32, %arg2: memref<1xf32>) {
// Check the error of the awaited token after resumption. // Check the error of the awaited token after resumption.
// CHECK: ^[[RESUME_1]]: // CHECK: ^[[RESUME_1]]:
// CHECK: %[[ERR:.*]] = async.runtime.is_error %[[INNER_TOKEN]] // CHECK: %[[ERR:.*]] = async.runtime.is_error %[[INNER_TOKEN]]
// CHECK: cond_br %[[ERR]], ^[[SET_ERROR:.*]], ^[[CONTINUATION:.*]] // CHECK: cf.cond_br %[[ERR]], ^[[SET_ERROR:.*]], ^[[CONTINUATION:.*]]
// Set token available if the token is not in the error state. // Set token available if the token is not in the error state.
// CHECK: ^[[CONTINUATION:.*]]: // CHECK: ^[[CONTINUATION:.*]]:
@ -169,7 +169,7 @@ func @async_execute_token_dependency(%arg0: f32, %arg1: memref<1xf32>) {
// Check the error of the awaited token after resumption. // Check the error of the awaited token after resumption.
// CHECK: ^[[RESUME_1]]: // CHECK: ^[[RESUME_1]]:
// CHECK: %[[ERR:.*]] = async.runtime.is_error %[[ARG0]] // CHECK: %[[ERR:.*]] = async.runtime.is_error %[[ARG0]]
// CHECK: cond_br %[[ERR]], ^[[SET_ERROR:.*]], ^[[CONTINUATION:.*]] // CHECK: cf.cond_br %[[ERR]], ^[[SET_ERROR:.*]], ^[[CONTINUATION:.*]]
// Emplace result token after second resumption and error checking. // Emplace result token after second resumption and error checking.
// CHECK: ^[[CONTINUATION:.*]]: // CHECK: ^[[CONTINUATION:.*]]:
@ -225,7 +225,7 @@ func @async_group_await_all(%arg0: f32, %arg1: memref<1xf32>) {
// Check the error of the awaited token after resumption. // Check the error of the awaited token after resumption.
// CHECK: ^[[RESUME_1]]: // CHECK: ^[[RESUME_1]]:
// CHECK: %[[ERR:.*]] = async.runtime.is_error %[[ARG]] // CHECK: %[[ERR:.*]] = async.runtime.is_error %[[ARG]]
// CHECK: cond_br %[[ERR]], ^[[SET_ERROR:.*]], ^[[CONTINUATION:.*]] // CHECK: cf.cond_br %[[ERR]], ^[[SET_ERROR:.*]], ^[[CONTINUATION:.*]]
// Emplace result token after error checking. // Emplace result token after error checking.
// CHECK: ^[[CONTINUATION:.*]]: // CHECK: ^[[CONTINUATION:.*]]:
@ -319,7 +319,7 @@ func @async_value_operands() {
// Check the error of the awaited token after resumption. // Check the error of the awaited token after resumption.
// CHECK: ^[[RESUME_1]]: // CHECK: ^[[RESUME_1]]:
// CHECK: %[[ERR:.*]] = async.runtime.is_error %[[ARG]] // CHECK: %[[ERR:.*]] = async.runtime.is_error %[[ARG]]
// CHECK: cond_br %[[ERR]], ^[[SET_ERROR:.*]], ^[[CONTINUATION:.*]] // CHECK: cf.cond_br %[[ERR]], ^[[SET_ERROR:.*]], ^[[CONTINUATION:.*]]
// // Load from the async.value argument after error checking. // // Load from the async.value argument after error checking.
// CHECK: ^[[CONTINUATION:.*]]: // CHECK: ^[[CONTINUATION:.*]]:
@ -335,7 +335,7 @@ func @async_value_operands() {
// CHECK-LABEL: @execute_assertion // CHECK-LABEL: @execute_assertion
func @execute_assertion(%arg0: i1) { func @execute_assertion(%arg0: i1) {
%token = async.execute { %token = async.execute {
assert %arg0, "error" cf.assert %arg0, "error"
async.yield async.yield
} }
async.await %token : !async.token async.await %token : !async.token
@ -358,17 +358,17 @@ func @execute_assertion(%arg0: i1) {
// Resume coroutine after suspension. // Resume coroutine after suspension.
// CHECK: ^[[RESUME]]: // CHECK: ^[[RESUME]]:
// CHECK: cond_br %[[ARG0]], ^[[SET_AVAILABLE:.*]], ^[[SET_ERROR:.*]] // CHECK: cf.cond_br %[[ARG0]], ^[[SET_AVAILABLE:.*]], ^[[SET_ERROR:.*]]
// Set coroutine completion token to available state. // Set coroutine completion token to available state.
// CHECK: ^[[SET_AVAILABLE]]: // CHECK: ^[[SET_AVAILABLE]]:
// CHECK: async.runtime.set_available %[[TOKEN]] // CHECK: async.runtime.set_available %[[TOKEN]]
// CHECK: br ^[[CLEANUP]] // CHECK: cf.br ^[[CLEANUP]]
// Set coroutine completion token to error state. // Set coroutine completion token to error state.
// CHECK: ^[[SET_ERROR]]: // CHECK: ^[[SET_ERROR]]:
// CHECK: async.runtime.set_error %[[TOKEN]] // CHECK: async.runtime.set_error %[[TOKEN]]
// CHECK: br ^[[CLEANUP]] // CHECK: cf.br ^[[CLEANUP]]
// Delete coroutine. // Delete coroutine.
// CHECK: ^[[CLEANUP]]: // CHECK: ^[[CLEANUP]]:
@ -409,7 +409,7 @@ func @lower_scf_to_cfg(%arg0: f32, %arg1: memref<1xf32>, %arg2: i1) {
// Check that structured control flow lowered to CFG. // Check that structured control flow lowered to CFG.
// CHECK-NOT: scf.if // CHECK-NOT: scf.if
// CHECK: cond_br %[[FLAG]] // CHECK: cf.cond_br %[[FLAG]]
// ----- // -----
// Constants captured by the async.execute region should be cloned into the // Constants captured by the async.execute region should be cloned into the

View File

@ -17,26 +17,26 @@
// CHECK-LABEL: func @condBranch // CHECK-LABEL: func @condBranch
func @condBranch(%arg0: i1, %arg1: memref<2xf32>, %arg2: memref<2xf32>) { func @condBranch(%arg0: i1, %arg1: memref<2xf32>, %arg2: memref<2xf32>) {
cond_br %arg0, ^bb1, ^bb2 cf.cond_br %arg0, ^bb1, ^bb2
^bb1: ^bb1:
br ^bb3(%arg1 : memref<2xf32>) cf.br ^bb3(%arg1 : memref<2xf32>)
^bb2: ^bb2:
%0 = memref.alloc() : memref<2xf32> %0 = memref.alloc() : memref<2xf32>
test.buffer_based in(%arg1: memref<2xf32>) out(%0: memref<2xf32>) test.buffer_based in(%arg1: memref<2xf32>) out(%0: memref<2xf32>)
br ^bb3(%0 : memref<2xf32>) cf.br ^bb3(%0 : memref<2xf32>)
^bb3(%1: memref<2xf32>): ^bb3(%1: memref<2xf32>):
test.copy(%1, %arg2) : (memref<2xf32>, memref<2xf32>) test.copy(%1, %arg2) : (memref<2xf32>, memref<2xf32>)
return return
} }
// CHECK-NEXT: cond_br // CHECK-NEXT: cf.cond_br
// CHECK: %[[ALLOC0:.*]] = bufferization.clone // CHECK: %[[ALLOC0:.*]] = bufferization.clone
// CHECK-NEXT: br ^bb3(%[[ALLOC0]] // CHECK-NEXT: cf.br ^bb3(%[[ALLOC0]]
// CHECK: %[[ALLOC1:.*]] = memref.alloc // CHECK: %[[ALLOC1:.*]] = memref.alloc
// CHECK-NEXT: test.buffer_based // CHECK-NEXT: test.buffer_based
// CHECK-NEXT: %[[ALLOC2:.*]] = bufferization.clone %[[ALLOC1]] // CHECK-NEXT: %[[ALLOC2:.*]] = bufferization.clone %[[ALLOC1]]
// CHECK-NEXT: memref.dealloc %[[ALLOC1]] // CHECK-NEXT: memref.dealloc %[[ALLOC1]]
// CHECK-NEXT: br ^bb3(%[[ALLOC2]] // CHECK-NEXT: cf.br ^bb3(%[[ALLOC2]]
// CHECK: test.copy // CHECK: test.copy
// CHECK-NEXT: memref.dealloc // CHECK-NEXT: memref.dealloc
// CHECK-NEXT: return // CHECK-NEXT: return
@ -62,27 +62,27 @@ func @condBranchDynamicType(
%arg1: memref<?xf32>, %arg1: memref<?xf32>,
%arg2: memref<?xf32>, %arg2: memref<?xf32>,
%arg3: index) { %arg3: index) {
cond_br %arg0, ^bb1, ^bb2(%arg3: index) cf.cond_br %arg0, ^bb1, ^bb2(%arg3: index)
^bb1: ^bb1:
br ^bb3(%arg1 : memref<?xf32>) cf.br ^bb3(%arg1 : memref<?xf32>)
^bb2(%0: index): ^bb2(%0: index):
%1 = memref.alloc(%0) : memref<?xf32> %1 = memref.alloc(%0) : memref<?xf32>
test.buffer_based in(%arg1: memref<?xf32>) out(%1: memref<?xf32>) test.buffer_based in(%arg1: memref<?xf32>) out(%1: memref<?xf32>)
br ^bb3(%1 : memref<?xf32>) cf.br ^bb3(%1 : memref<?xf32>)
^bb3(%2: memref<?xf32>): ^bb3(%2: memref<?xf32>):
test.copy(%2, %arg2) : (memref<?xf32>, memref<?xf32>) test.copy(%2, %arg2) : (memref<?xf32>, memref<?xf32>)
return return
} }
// CHECK-NEXT: cond_br // CHECK-NEXT: cf.cond_br
// CHECK: %[[ALLOC0:.*]] = bufferization.clone // CHECK: %[[ALLOC0:.*]] = bufferization.clone
// CHECK-NEXT: br ^bb3(%[[ALLOC0]] // CHECK-NEXT: cf.br ^bb3(%[[ALLOC0]]
// CHECK: ^bb2(%[[IDX:.*]]:{{.*}}) // CHECK: ^bb2(%[[IDX:.*]]:{{.*}})
// CHECK-NEXT: %[[ALLOC1:.*]] = memref.alloc(%[[IDX]]) // CHECK-NEXT: %[[ALLOC1:.*]] = memref.alloc(%[[IDX]])
// CHECK-NEXT: test.buffer_based // CHECK-NEXT: test.buffer_based
// CHECK-NEXT: %[[ALLOC2:.*]] = bufferization.clone // CHECK-NEXT: %[[ALLOC2:.*]] = bufferization.clone
// CHECK-NEXT: memref.dealloc %[[ALLOC1]] // CHECK-NEXT: memref.dealloc %[[ALLOC1]]
// CHECK-NEXT: br ^bb3 // CHECK-NEXT: cf.br ^bb3
// CHECK-NEXT: ^bb3(%[[ALLOC3:.*]]:{{.*}}) // CHECK-NEXT: ^bb3(%[[ALLOC3:.*]]:{{.*}})
// CHECK: test.copy(%[[ALLOC3]], // CHECK: test.copy(%[[ALLOC3]],
// CHECK-NEXT: memref.dealloc %[[ALLOC3]] // CHECK-NEXT: memref.dealloc %[[ALLOC3]]
@ -98,28 +98,28 @@ func @condBranchUnrankedType(
%arg1: memref<*xf32>, %arg1: memref<*xf32>,
%arg2: memref<*xf32>, %arg2: memref<*xf32>,
%arg3: index) { %arg3: index) {
cond_br %arg0, ^bb1, ^bb2(%arg3: index) cf.cond_br %arg0, ^bb1, ^bb2(%arg3: index)
^bb1: ^bb1:
br ^bb3(%arg1 : memref<*xf32>) cf.br ^bb3(%arg1 : memref<*xf32>)
^bb2(%0: index): ^bb2(%0: index):
%1 = memref.alloc(%0) : memref<?xf32> %1 = memref.alloc(%0) : memref<?xf32>
%2 = memref.cast %1 : memref<?xf32> to memref<*xf32> %2 = memref.cast %1 : memref<?xf32> to memref<*xf32>
test.buffer_based in(%arg1: memref<*xf32>) out(%2: memref<*xf32>) test.buffer_based in(%arg1: memref<*xf32>) out(%2: memref<*xf32>)
br ^bb3(%2 : memref<*xf32>) cf.br ^bb3(%2 : memref<*xf32>)
^bb3(%3: memref<*xf32>): ^bb3(%3: memref<*xf32>):
test.copy(%3, %arg2) : (memref<*xf32>, memref<*xf32>) test.copy(%3, %arg2) : (memref<*xf32>, memref<*xf32>)
return return
} }
// CHECK-NEXT: cond_br // CHECK-NEXT: cf.cond_br
// CHECK: %[[ALLOC0:.*]] = bufferization.clone // CHECK: %[[ALLOC0:.*]] = bufferization.clone
// CHECK-NEXT: br ^bb3(%[[ALLOC0]] // CHECK-NEXT: cf.br ^bb3(%[[ALLOC0]]
// CHECK: ^bb2(%[[IDX:.*]]:{{.*}}) // CHECK: ^bb2(%[[IDX:.*]]:{{.*}})
// CHECK-NEXT: %[[ALLOC1:.*]] = memref.alloc(%[[IDX]]) // CHECK-NEXT: %[[ALLOC1:.*]] = memref.alloc(%[[IDX]])
// CHECK: test.buffer_based // CHECK: test.buffer_based
// CHECK-NEXT: %[[ALLOC2:.*]] = bufferization.clone // CHECK-NEXT: %[[ALLOC2:.*]] = bufferization.clone
// CHECK-NEXT: memref.dealloc %[[ALLOC1]] // CHECK-NEXT: memref.dealloc %[[ALLOC1]]
// CHECK-NEXT: br ^bb3 // CHECK-NEXT: cf.br ^bb3
// CHECK-NEXT: ^bb3(%[[ALLOC3:.*]]:{{.*}}) // CHECK-NEXT: ^bb3(%[[ALLOC3:.*]]:{{.*}})
// CHECK: test.copy(%[[ALLOC3]], // CHECK: test.copy(%[[ALLOC3]],
// CHECK-NEXT: memref.dealloc %[[ALLOC3]] // CHECK-NEXT: memref.dealloc %[[ALLOC3]]
@ -153,44 +153,44 @@ func @condBranchDynamicTypeNested(
%arg1: memref<?xf32>, %arg1: memref<?xf32>,
%arg2: memref<?xf32>, %arg2: memref<?xf32>,
%arg3: index) { %arg3: index) {
cond_br %arg0, ^bb1, ^bb2(%arg3: index) cf.cond_br %arg0, ^bb1, ^bb2(%arg3: index)
^bb1: ^bb1:
br ^bb6(%arg1 : memref<?xf32>) cf.br ^bb6(%arg1 : memref<?xf32>)
^bb2(%0: index): ^bb2(%0: index):
%1 = memref.alloc(%0) : memref<?xf32> %1 = memref.alloc(%0) : memref<?xf32>
test.buffer_based in(%arg1: memref<?xf32>) out(%1: memref<?xf32>) test.buffer_based in(%arg1: memref<?xf32>) out(%1: memref<?xf32>)
cond_br %arg0, ^bb3, ^bb4 cf.cond_br %arg0, ^bb3, ^bb4
^bb3: ^bb3:
br ^bb5(%1 : memref<?xf32>) cf.br ^bb5(%1 : memref<?xf32>)
^bb4: ^bb4:
br ^bb5(%1 : memref<?xf32>) cf.br ^bb5(%1 : memref<?xf32>)
^bb5(%2: memref<?xf32>): ^bb5(%2: memref<?xf32>):
br ^bb6(%2 : memref<?xf32>) cf.br ^bb6(%2 : memref<?xf32>)
^bb6(%3: memref<?xf32>): ^bb6(%3: memref<?xf32>):
br ^bb7(%3 : memref<?xf32>) cf.br ^bb7(%3 : memref<?xf32>)
^bb7(%4: memref<?xf32>): ^bb7(%4: memref<?xf32>):
test.copy(%4, %arg2) : (memref<?xf32>, memref<?xf32>) test.copy(%4, %arg2) : (memref<?xf32>, memref<?xf32>)
return return
} }
// CHECK-NEXT: cond_br{{.*}} // CHECK-NEXT: cf.cond_br{{.*}}
// CHECK-NEXT: ^bb1 // CHECK-NEXT: ^bb1
// CHECK-NEXT: %[[ALLOC0:.*]] = bufferization.clone // CHECK-NEXT: %[[ALLOC0:.*]] = bufferization.clone
// CHECK-NEXT: br ^bb6(%[[ALLOC0]] // CHECK-NEXT: cf.br ^bb6(%[[ALLOC0]]
// CHECK: ^bb2(%[[IDX:.*]]:{{.*}}) // CHECK: ^bb2(%[[IDX:.*]]:{{.*}})
// CHECK-NEXT: %[[ALLOC1:.*]] = memref.alloc(%[[IDX]]) // CHECK-NEXT: %[[ALLOC1:.*]] = memref.alloc(%[[IDX]])
// CHECK-NEXT: test.buffer_based // CHECK-NEXT: test.buffer_based
// CHECK: cond_br // CHECK: cf.cond_br
// CHECK: ^bb3: // CHECK: ^bb3:
// CHECK-NEXT: br ^bb5(%[[ALLOC1]]{{.*}}) // CHECK-NEXT: cf.br ^bb5(%[[ALLOC1]]{{.*}})
// CHECK: ^bb4: // CHECK: ^bb4:
// CHECK-NEXT: br ^bb5(%[[ALLOC1]]{{.*}}) // CHECK-NEXT: cf.br ^bb5(%[[ALLOC1]]{{.*}})
// CHECK-NEXT: ^bb5(%[[ALLOC2:.*]]:{{.*}}) // CHECK-NEXT: ^bb5(%[[ALLOC2:.*]]:{{.*}})
// CHECK-NEXT: %[[ALLOC3:.*]] = bufferization.clone %[[ALLOC2]] // CHECK-NEXT: %[[ALLOC3:.*]] = bufferization.clone %[[ALLOC2]]
// CHECK-NEXT: memref.dealloc %[[ALLOC1]] // CHECK-NEXT: memref.dealloc %[[ALLOC1]]
// CHECK-NEXT: br ^bb6(%[[ALLOC3]]{{.*}}) // CHECK-NEXT: cf.br ^bb6(%[[ALLOC3]]{{.*}})
// CHECK-NEXT: ^bb6(%[[ALLOC4:.*]]:{{.*}}) // CHECK-NEXT: ^bb6(%[[ALLOC4:.*]]:{{.*}})
// CHECK-NEXT: br ^bb7(%[[ALLOC4]]{{.*}}) // CHECK-NEXT: cf.br ^bb7(%[[ALLOC4]]{{.*}})
// CHECK-NEXT: ^bb7(%[[ALLOC5:.*]]:{{.*}}) // CHECK-NEXT: ^bb7(%[[ALLOC5:.*]]:{{.*}})
// CHECK: test.copy(%[[ALLOC5]], // CHECK: test.copy(%[[ALLOC5]],
// CHECK-NEXT: memref.dealloc %[[ALLOC4]] // CHECK-NEXT: memref.dealloc %[[ALLOC4]]
@ -225,18 +225,18 @@ func @emptyUsesValue(%arg0: memref<4xf32>) {
// CHECK-LABEL: func @criticalEdge // CHECK-LABEL: func @criticalEdge
func @criticalEdge(%arg0: i1, %arg1: memref<2xf32>, %arg2: memref<2xf32>) { func @criticalEdge(%arg0: i1, %arg1: memref<2xf32>, %arg2: memref<2xf32>) {
cond_br %arg0, ^bb1, ^bb2(%arg1 : memref<2xf32>) cf.cond_br %arg0, ^bb1, ^bb2(%arg1 : memref<2xf32>)
^bb1: ^bb1:
%0 = memref.alloc() : memref<2xf32> %0 = memref.alloc() : memref<2xf32>
test.buffer_based in(%arg1: memref<2xf32>) out(%0: memref<2xf32>) test.buffer_based in(%arg1: memref<2xf32>) out(%0: memref<2xf32>)
br ^bb2(%0 : memref<2xf32>) cf.br ^bb2(%0 : memref<2xf32>)
^bb2(%1: memref<2xf32>): ^bb2(%1: memref<2xf32>):
test.copy(%1, %arg2) : (memref<2xf32>, memref<2xf32>) test.copy(%1, %arg2) : (memref<2xf32>, memref<2xf32>)
return return
} }
// CHECK-NEXT: %[[ALLOC0:.*]] = bufferization.clone // CHECK-NEXT: %[[ALLOC0:.*]] = bufferization.clone
// CHECK-NEXT: cond_br // CHECK-NEXT: cf.cond_br
// CHECK: %[[ALLOC1:.*]] = memref.alloc() // CHECK: %[[ALLOC1:.*]] = memref.alloc()
// CHECK-NEXT: test.buffer_based // CHECK-NEXT: test.buffer_based
// CHECK-NEXT: %[[ALLOC2:.*]] = bufferization.clone %[[ALLOC1]] // CHECK-NEXT: %[[ALLOC2:.*]] = bufferization.clone %[[ALLOC1]]
@ -260,9 +260,9 @@ func @criticalEdge(%arg0: i1, %arg1: memref<2xf32>, %arg2: memref<2xf32>) {
func @invCriticalEdge(%arg0: i1, %arg1: memref<2xf32>, %arg2: memref<2xf32>) { func @invCriticalEdge(%arg0: i1, %arg1: memref<2xf32>, %arg2: memref<2xf32>) {
%0 = memref.alloc() : memref<2xf32> %0 = memref.alloc() : memref<2xf32>
test.buffer_based in(%arg1: memref<2xf32>) out(%0: memref<2xf32>) test.buffer_based in(%arg1: memref<2xf32>) out(%0: memref<2xf32>)
cond_br %arg0, ^bb1, ^bb2(%arg1 : memref<2xf32>) cf.cond_br %arg0, ^bb1, ^bb2(%arg1 : memref<2xf32>)
^bb1: ^bb1:
br ^bb2(%0 : memref<2xf32>) cf.br ^bb2(%0 : memref<2xf32>)
^bb2(%1: memref<2xf32>): ^bb2(%1: memref<2xf32>):
test.copy(%1, %arg2) : (memref<2xf32>, memref<2xf32>) test.copy(%1, %arg2) : (memref<2xf32>, memref<2xf32>)
return return
@ -288,13 +288,13 @@ func @invCriticalEdge(%arg0: i1, %arg1: memref<2xf32>, %arg2: memref<2xf32>) {
func @ifElse(%arg0: i1, %arg1: memref<2xf32>, %arg2: memref<2xf32>) { func @ifElse(%arg0: i1, %arg1: memref<2xf32>, %arg2: memref<2xf32>) {
%0 = memref.alloc() : memref<2xf32> %0 = memref.alloc() : memref<2xf32>
test.buffer_based in(%arg1: memref<2xf32>) out(%0: memref<2xf32>) test.buffer_based in(%arg1: memref<2xf32>) out(%0: memref<2xf32>)
cond_br %arg0, cf.cond_br %arg0,
^bb1(%arg1, %0 : memref<2xf32>, memref<2xf32>), ^bb1(%arg1, %0 : memref<2xf32>, memref<2xf32>),
^bb2(%0, %arg1 : memref<2xf32>, memref<2xf32>) ^bb2(%0, %arg1 : memref<2xf32>, memref<2xf32>)
^bb1(%1: memref<2xf32>, %2: memref<2xf32>): ^bb1(%1: memref<2xf32>, %2: memref<2xf32>):
br ^bb3(%1, %2 : memref<2xf32>, memref<2xf32>) cf.br ^bb3(%1, %2 : memref<2xf32>, memref<2xf32>)
^bb2(%3: memref<2xf32>, %4: memref<2xf32>): ^bb2(%3: memref<2xf32>, %4: memref<2xf32>):
br ^bb3(%3, %4 : memref<2xf32>, memref<2xf32>) cf.br ^bb3(%3, %4 : memref<2xf32>, memref<2xf32>)
^bb3(%5: memref<2xf32>, %6: memref<2xf32>): ^bb3(%5: memref<2xf32>, %6: memref<2xf32>):
%7 = memref.alloc() : memref<2xf32> %7 = memref.alloc() : memref<2xf32>
test.buffer_based in(%5: memref<2xf32>) out(%7: memref<2xf32>) test.buffer_based in(%5: memref<2xf32>) out(%7: memref<2xf32>)
@ -326,13 +326,13 @@ func @ifElse(%arg0: i1, %arg1: memref<2xf32>, %arg2: memref<2xf32>) {
func @ifElseNoUsers(%arg0: i1, %arg1: memref<2xf32>, %arg2: memref<2xf32>) { func @ifElseNoUsers(%arg0: i1, %arg1: memref<2xf32>, %arg2: memref<2xf32>) {
%0 = memref.alloc() : memref<2xf32> %0 = memref.alloc() : memref<2xf32>
test.buffer_based in(%arg1: memref<2xf32>) out(%0: memref<2xf32>) test.buffer_based in(%arg1: memref<2xf32>) out(%0: memref<2xf32>)
cond_br %arg0, cf.cond_br %arg0,
^bb1(%arg1, %0 : memref<2xf32>, memref<2xf32>), ^bb1(%arg1, %0 : memref<2xf32>, memref<2xf32>),
^bb2(%0, %arg1 : memref<2xf32>, memref<2xf32>) ^bb2(%0, %arg1 : memref<2xf32>, memref<2xf32>)
^bb1(%1: memref<2xf32>, %2: memref<2xf32>): ^bb1(%1: memref<2xf32>, %2: memref<2xf32>):
br ^bb3(%1, %2 : memref<2xf32>, memref<2xf32>) cf.br ^bb3(%1, %2 : memref<2xf32>, memref<2xf32>)
^bb2(%3: memref<2xf32>, %4: memref<2xf32>): ^bb2(%3: memref<2xf32>, %4: memref<2xf32>):
br ^bb3(%3, %4 : memref<2xf32>, memref<2xf32>) cf.br ^bb3(%3, %4 : memref<2xf32>, memref<2xf32>)
^bb3(%5: memref<2xf32>, %6: memref<2xf32>): ^bb3(%5: memref<2xf32>, %6: memref<2xf32>):
test.copy(%arg1, %arg2) : (memref<2xf32>, memref<2xf32>) test.copy(%arg1, %arg2) : (memref<2xf32>, memref<2xf32>)
return return
@ -361,17 +361,17 @@ func @ifElseNoUsers(%arg0: i1, %arg1: memref<2xf32>, %arg2: memref<2xf32>) {
func @ifElseNested(%arg0: i1, %arg1: memref<2xf32>, %arg2: memref<2xf32>) { func @ifElseNested(%arg0: i1, %arg1: memref<2xf32>, %arg2: memref<2xf32>) {
%0 = memref.alloc() : memref<2xf32> %0 = memref.alloc() : memref<2xf32>
test.buffer_based in(%arg1: memref<2xf32>) out(%0: memref<2xf32>) test.buffer_based in(%arg1: memref<2xf32>) out(%0: memref<2xf32>)
cond_br %arg0, cf.cond_br %arg0,
^bb1(%arg1, %0 : memref<2xf32>, memref<2xf32>), ^bb1(%arg1, %0 : memref<2xf32>, memref<2xf32>),
^bb2(%0, %arg1 : memref<2xf32>, memref<2xf32>) ^bb2(%0, %arg1 : memref<2xf32>, memref<2xf32>)
^bb1(%1: memref<2xf32>, %2: memref<2xf32>): ^bb1(%1: memref<2xf32>, %2: memref<2xf32>):
br ^bb5(%1, %2 : memref<2xf32>, memref<2xf32>) cf.br ^bb5(%1, %2 : memref<2xf32>, memref<2xf32>)
^bb2(%3: memref<2xf32>, %4: memref<2xf32>): ^bb2(%3: memref<2xf32>, %4: memref<2xf32>):
cond_br %arg0, ^bb3(%3 : memref<2xf32>), ^bb4(%4 : memref<2xf32>) cf.cond_br %arg0, ^bb3(%3 : memref<2xf32>), ^bb4(%4 : memref<2xf32>)
^bb3(%5: memref<2xf32>): ^bb3(%5: memref<2xf32>):
br ^bb5(%5, %3 : memref<2xf32>, memref<2xf32>) cf.br ^bb5(%5, %3 : memref<2xf32>, memref<2xf32>)
^bb4(%6: memref<2xf32>): ^bb4(%6: memref<2xf32>):
br ^bb5(%3, %6 : memref<2xf32>, memref<2xf32>) cf.br ^bb5(%3, %6 : memref<2xf32>, memref<2xf32>)
^bb5(%7: memref<2xf32>, %8: memref<2xf32>): ^bb5(%7: memref<2xf32>, %8: memref<2xf32>):
%9 = memref.alloc() : memref<2xf32> %9 = memref.alloc() : memref<2xf32>
test.buffer_based in(%7: memref<2xf32>) out(%9: memref<2xf32>) test.buffer_based in(%7: memref<2xf32>) out(%9: memref<2xf32>)
@ -430,33 +430,33 @@ func @moving_alloc_and_inserting_missing_dealloc(
%cond: i1, %cond: i1,
%arg0: memref<2xf32>, %arg0: memref<2xf32>,
%arg1: memref<2xf32>) { %arg1: memref<2xf32>) {
cond_br %cond, ^bb1, ^bb2 cf.cond_br %cond, ^bb1, ^bb2
^bb1: ^bb1:
%0 = memref.alloc() : memref<2xf32> %0 = memref.alloc() : memref<2xf32>
test.buffer_based in(%arg0: memref<2xf32>) out(%0: memref<2xf32>) test.buffer_based in(%arg0: memref<2xf32>) out(%0: memref<2xf32>)
br ^exit(%0 : memref<2xf32>) cf.br ^exit(%0 : memref<2xf32>)
^bb2: ^bb2:
%1 = memref.alloc() : memref<2xf32> %1 = memref.alloc() : memref<2xf32>
test.buffer_based in(%arg0: memref<2xf32>) out(%1: memref<2xf32>) test.buffer_based in(%arg0: memref<2xf32>) out(%1: memref<2xf32>)
br ^exit(%1 : memref<2xf32>) cf.br ^exit(%1 : memref<2xf32>)
^exit(%arg2: memref<2xf32>): ^exit(%arg2: memref<2xf32>):
test.copy(%arg2, %arg1) : (memref<2xf32>, memref<2xf32>) test.copy(%arg2, %arg1) : (memref<2xf32>, memref<2xf32>)
return return
} }
// CHECK-NEXT: cond_br{{.*}} // CHECK-NEXT: cf.cond_br{{.*}}
// CHECK-NEXT: ^bb1 // CHECK-NEXT: ^bb1
// CHECK: %[[ALLOC0:.*]] = memref.alloc() // CHECK: %[[ALLOC0:.*]] = memref.alloc()
// CHECK-NEXT: test.buffer_based // CHECK-NEXT: test.buffer_based
// CHECK-NEXT: %[[ALLOC1:.*]] = bufferization.clone %[[ALLOC0]] // CHECK-NEXT: %[[ALLOC1:.*]] = bufferization.clone %[[ALLOC0]]
// CHECK-NEXT: memref.dealloc %[[ALLOC0]] // CHECK-NEXT: memref.dealloc %[[ALLOC0]]
// CHECK-NEXT: br ^bb3(%[[ALLOC1]] // CHECK-NEXT: cf.br ^bb3(%[[ALLOC1]]
// CHECK-NEXT: ^bb2 // CHECK-NEXT: ^bb2
// CHECK-NEXT: %[[ALLOC2:.*]] = memref.alloc() // CHECK-NEXT: %[[ALLOC2:.*]] = memref.alloc()
// CHECK-NEXT: test.buffer_based // CHECK-NEXT: test.buffer_based
// CHECK-NEXT: %[[ALLOC3:.*]] = bufferization.clone %[[ALLOC2]] // CHECK-NEXT: %[[ALLOC3:.*]] = bufferization.clone %[[ALLOC2]]
// CHECK-NEXT: memref.dealloc %[[ALLOC2]] // CHECK-NEXT: memref.dealloc %[[ALLOC2]]
// CHECK-NEXT: br ^bb3(%[[ALLOC3]] // CHECK-NEXT: cf.br ^bb3(%[[ALLOC3]]
// CHECK-NEXT: ^bb3(%[[ALLOC4:.*]]:{{.*}}) // CHECK-NEXT: ^bb3(%[[ALLOC4:.*]]:{{.*}})
// CHECK: test.copy // CHECK: test.copy
// CHECK-NEXT: memref.dealloc %[[ALLOC4]] // CHECK-NEXT: memref.dealloc %[[ALLOC4]]
@ -480,20 +480,20 @@ func @moving_invalid_dealloc_op_complex(
%arg0: memref<2xf32>, %arg0: memref<2xf32>,
%arg1: memref<2xf32>) { %arg1: memref<2xf32>) {
%1 = memref.alloc() : memref<2xf32> %1 = memref.alloc() : memref<2xf32>
cond_br %cond, ^bb1, ^bb2 cf.cond_br %cond, ^bb1, ^bb2
^bb1: ^bb1:
br ^exit(%arg0 : memref<2xf32>) cf.br ^exit(%arg0 : memref<2xf32>)
^bb2: ^bb2:
test.buffer_based in(%arg0: memref<2xf32>) out(%1: memref<2xf32>) test.buffer_based in(%arg0: memref<2xf32>) out(%1: memref<2xf32>)
memref.dealloc %1 : memref<2xf32> memref.dealloc %1 : memref<2xf32>
br ^exit(%1 : memref<2xf32>) cf.br ^exit(%1 : memref<2xf32>)
^exit(%arg2: memref<2xf32>): ^exit(%arg2: memref<2xf32>):
test.copy(%arg2, %arg1) : (memref<2xf32>, memref<2xf32>) test.copy(%arg2, %arg1) : (memref<2xf32>, memref<2xf32>)
return return
} }
// CHECK-NEXT: %[[ALLOC0:.*]] = memref.alloc() // CHECK-NEXT: %[[ALLOC0:.*]] = memref.alloc()
// CHECK-NEXT: cond_br // CHECK-NEXT: cf.cond_br
// CHECK: test.copy // CHECK: test.copy
// CHECK-NEXT: memref.dealloc %[[ALLOC0]] // CHECK-NEXT: memref.dealloc %[[ALLOC0]]
// CHECK-NEXT: return // CHECK-NEXT: return
@ -548,9 +548,9 @@ func @nested_regions_and_cond_branch(
%arg0: i1, %arg0: i1,
%arg1: memref<2xf32>, %arg1: memref<2xf32>,
%arg2: memref<2xf32>) { %arg2: memref<2xf32>) {
cond_br %arg0, ^bb1, ^bb2 cf.cond_br %arg0, ^bb1, ^bb2
^bb1: ^bb1:
br ^bb3(%arg1 : memref<2xf32>) cf.br ^bb3(%arg1 : memref<2xf32>)
^bb2: ^bb2:
%0 = memref.alloc() : memref<2xf32> %0 = memref.alloc() : memref<2xf32>
test.region_buffer_based in(%arg1: memref<2xf32>) out(%0: memref<2xf32>) { test.region_buffer_based in(%arg1: memref<2xf32>) out(%0: memref<2xf32>) {
@ -560,13 +560,13 @@ func @nested_regions_and_cond_branch(
%tmp1 = math.exp %gen1_arg0 : f32 %tmp1 = math.exp %gen1_arg0 : f32
test.region_yield %tmp1 : f32 test.region_yield %tmp1 : f32
} }
br ^bb3(%0 : memref<2xf32>) cf.br ^bb3(%0 : memref<2xf32>)
^bb3(%1: memref<2xf32>): ^bb3(%1: memref<2xf32>):
test.copy(%1, %arg2) : (memref<2xf32>, memref<2xf32>) test.copy(%1, %arg2) : (memref<2xf32>, memref<2xf32>)
return return
} }
// CHECK: (%[[cond:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}, %{{.*}}: {{.*}}) // CHECK: (%[[cond:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}, %{{.*}}: {{.*}})
// CHECK-NEXT: cond_br %[[cond]], ^[[BB1:.*]], ^[[BB2:.*]] // CHECK-NEXT: cf.cond_br %[[cond]], ^[[BB1:.*]], ^[[BB2:.*]]
// CHECK: %[[ALLOC0:.*]] = bufferization.clone %[[ARG1]] // CHECK: %[[ALLOC0:.*]] = bufferization.clone %[[ARG1]]
// CHECK: ^[[BB2]]: // CHECK: ^[[BB2]]:
// CHECK: %[[ALLOC1:.*]] = memref.alloc() // CHECK: %[[ALLOC1:.*]] = memref.alloc()
@ -728,21 +728,21 @@ func @subview(%arg0 : index, %arg1 : index, %arg2 : memref<?x?xf32>) {
// CHECK-LABEL: func @condBranchAlloca // CHECK-LABEL: func @condBranchAlloca
func @condBranchAlloca(%arg0: i1, %arg1: memref<2xf32>, %arg2: memref<2xf32>) { func @condBranchAlloca(%arg0: i1, %arg1: memref<2xf32>, %arg2: memref<2xf32>) {
cond_br %arg0, ^bb1, ^bb2 cf.cond_br %arg0, ^bb1, ^bb2
^bb1: ^bb1:
br ^bb3(%arg1 : memref<2xf32>) cf.br ^bb3(%arg1 : memref<2xf32>)
^bb2: ^bb2:
%0 = memref.alloca() : memref<2xf32> %0 = memref.alloca() : memref<2xf32>
test.buffer_based in(%arg1: memref<2xf32>) out(%0: memref<2xf32>) test.buffer_based in(%arg1: memref<2xf32>) out(%0: memref<2xf32>)
br ^bb3(%0 : memref<2xf32>) cf.br ^bb3(%0 : memref<2xf32>)
^bb3(%1: memref<2xf32>): ^bb3(%1: memref<2xf32>):
test.copy(%1, %arg2) : (memref<2xf32>, memref<2xf32>) test.copy(%1, %arg2) : (memref<2xf32>, memref<2xf32>)
return return
} }
// CHECK-NEXT: cond_br // CHECK-NEXT: cf.cond_br
// CHECK: %[[ALLOCA:.*]] = memref.alloca() // CHECK: %[[ALLOCA:.*]] = memref.alloca()
// CHECK: br ^bb3(%[[ALLOCA:.*]]) // CHECK: cf.br ^bb3(%[[ALLOCA:.*]])
// CHECK-NEXT: ^bb3 // CHECK-NEXT: ^bb3
// CHECK-NEXT: test.copy // CHECK-NEXT: test.copy
// CHECK-NEXT: return // CHECK-NEXT: return
@ -757,13 +757,13 @@ func @condBranchAlloca(%arg0: i1, %arg1: memref<2xf32>, %arg2: memref<2xf32>) {
func @ifElseAlloca(%arg0: i1, %arg1: memref<2xf32>, %arg2: memref<2xf32>) { func @ifElseAlloca(%arg0: i1, %arg1: memref<2xf32>, %arg2: memref<2xf32>) {
%0 = memref.alloc() : memref<2xf32> %0 = memref.alloc() : memref<2xf32>
test.buffer_based in(%arg1: memref<2xf32>) out(%0: memref<2xf32>) test.buffer_based in(%arg1: memref<2xf32>) out(%0: memref<2xf32>)
cond_br %arg0, cf.cond_br %arg0,
^bb1(%arg1, %0 : memref<2xf32>, memref<2xf32>), ^bb1(%arg1, %0 : memref<2xf32>, memref<2xf32>),
^bb2(%0, %arg1 : memref<2xf32>, memref<2xf32>) ^bb2(%0, %arg1 : memref<2xf32>, memref<2xf32>)
^bb1(%1: memref<2xf32>, %2: memref<2xf32>): ^bb1(%1: memref<2xf32>, %2: memref<2xf32>):
br ^bb3(%1, %2 : memref<2xf32>, memref<2xf32>) cf.br ^bb3(%1, %2 : memref<2xf32>, memref<2xf32>)
^bb2(%3: memref<2xf32>, %4: memref<2xf32>): ^bb2(%3: memref<2xf32>, %4: memref<2xf32>):
br ^bb3(%3, %4 : memref<2xf32>, memref<2xf32>) cf.br ^bb3(%3, %4 : memref<2xf32>, memref<2xf32>)
^bb3(%5: memref<2xf32>, %6: memref<2xf32>): ^bb3(%5: memref<2xf32>, %6: memref<2xf32>):
%7 = memref.alloca() : memref<2xf32> %7 = memref.alloca() : memref<2xf32>
test.buffer_based in(%5: memref<2xf32>) out(%7: memref<2xf32>) test.buffer_based in(%5: memref<2xf32>) out(%7: memref<2xf32>)
@ -788,17 +788,17 @@ func @ifElseNestedAlloca(
%arg2: memref<2xf32>) { %arg2: memref<2xf32>) {
%0 = memref.alloca() : memref<2xf32> %0 = memref.alloca() : memref<2xf32>
test.buffer_based in(%arg1: memref<2xf32>) out(%0: memref<2xf32>) test.buffer_based in(%arg1: memref<2xf32>) out(%0: memref<2xf32>)
cond_br %arg0, cf.cond_br %arg0,
^bb1(%arg1, %0 : memref<2xf32>, memref<2xf32>), ^bb1(%arg1, %0 : memref<2xf32>, memref<2xf32>),
^bb2(%0, %arg1 : memref<2xf32>, memref<2xf32>) ^bb2(%0, %arg1 : memref<2xf32>, memref<2xf32>)
^bb1(%1: memref<2xf32>, %2: memref<2xf32>): ^bb1(%1: memref<2xf32>, %2: memref<2xf32>):
br ^bb5(%1, %2 : memref<2xf32>, memref<2xf32>) cf.br ^bb5(%1, %2 : memref<2xf32>, memref<2xf32>)
^bb2(%3: memref<2xf32>, %4: memref<2xf32>): ^bb2(%3: memref<2xf32>, %4: memref<2xf32>):
cond_br %arg0, ^bb3(%3 : memref<2xf32>), ^bb4(%4 : memref<2xf32>) cf.cond_br %arg0, ^bb3(%3 : memref<2xf32>), ^bb4(%4 : memref<2xf32>)
^bb3(%5: memref<2xf32>): ^bb3(%5: memref<2xf32>):
br ^bb5(%5, %3 : memref<2xf32>, memref<2xf32>) cf.br ^bb5(%5, %3 : memref<2xf32>, memref<2xf32>)
^bb4(%6: memref<2xf32>): ^bb4(%6: memref<2xf32>):
br ^bb5(%3, %6 : memref<2xf32>, memref<2xf32>) cf.br ^bb5(%3, %6 : memref<2xf32>, memref<2xf32>)
^bb5(%7: memref<2xf32>, %8: memref<2xf32>): ^bb5(%7: memref<2xf32>, %8: memref<2xf32>):
%9 = memref.alloc() : memref<2xf32> %9 = memref.alloc() : memref<2xf32>
test.buffer_based in(%7: memref<2xf32>) out(%9: memref<2xf32>) test.buffer_based in(%7: memref<2xf32>) out(%9: memref<2xf32>)
@ -821,9 +821,9 @@ func @nestedRegionsAndCondBranchAlloca(
%arg0: i1, %arg0: i1,
%arg1: memref<2xf32>, %arg1: memref<2xf32>,
%arg2: memref<2xf32>) { %arg2: memref<2xf32>) {
cond_br %arg0, ^bb1, ^bb2 cf.cond_br %arg0, ^bb1, ^bb2
^bb1: ^bb1:
br ^bb3(%arg1 : memref<2xf32>) cf.br ^bb3(%arg1 : memref<2xf32>)
^bb2: ^bb2:
%0 = memref.alloc() : memref<2xf32> %0 = memref.alloc() : memref<2xf32>
test.region_buffer_based in(%arg1: memref<2xf32>) out(%0: memref<2xf32>) { test.region_buffer_based in(%arg1: memref<2xf32>) out(%0: memref<2xf32>) {
@ -833,13 +833,13 @@ func @nestedRegionsAndCondBranchAlloca(
%tmp1 = math.exp %gen1_arg0 : f32 %tmp1 = math.exp %gen1_arg0 : f32
test.region_yield %tmp1 : f32 test.region_yield %tmp1 : f32
} }
br ^bb3(%0 : memref<2xf32>) cf.br ^bb3(%0 : memref<2xf32>)
^bb3(%1: memref<2xf32>): ^bb3(%1: memref<2xf32>):
test.copy(%1, %arg2) : (memref<2xf32>, memref<2xf32>) test.copy(%1, %arg2) : (memref<2xf32>, memref<2xf32>)
return return
} }
// CHECK: (%[[cond:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}, %{{.*}}: {{.*}}) // CHECK: (%[[cond:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}, %{{.*}}: {{.*}})
// CHECK-NEXT: cond_br %[[cond]], ^[[BB1:.*]], ^[[BB2:.*]] // CHECK-NEXT: cf.cond_br %[[cond]], ^[[BB1:.*]], ^[[BB2:.*]]
// CHECK: ^[[BB1]]: // CHECK: ^[[BB1]]:
// CHECK: %[[ALLOC0:.*]] = bufferization.clone // CHECK: %[[ALLOC0:.*]] = bufferization.clone
// CHECK: ^[[BB2]]: // CHECK: ^[[BB2]]:
@ -1103,11 +1103,11 @@ func @loop_dynalloc(
%arg2: memref<?xf32>, %arg2: memref<?xf32>,
%arg3: memref<?xf32>) { %arg3: memref<?xf32>) {
%const0 = arith.constant 0 : i32 %const0 = arith.constant 0 : i32
br ^loopHeader(%const0, %arg2 : i32, memref<?xf32>) cf.br ^loopHeader(%const0, %arg2 : i32, memref<?xf32>)
^loopHeader(%i : i32, %buff : memref<?xf32>): ^loopHeader(%i : i32, %buff : memref<?xf32>):
%lessThan = arith.cmpi slt, %i, %arg1 : i32 %lessThan = arith.cmpi slt, %i, %arg1 : i32
cond_br %lessThan, cf.cond_br %lessThan,
^loopBody(%i, %buff : i32, memref<?xf32>), ^loopBody(%i, %buff : i32, memref<?xf32>),
^exit(%buff : memref<?xf32>) ^exit(%buff : memref<?xf32>)
@ -1116,7 +1116,7 @@ func @loop_dynalloc(
%inc = arith.addi %val, %const1 : i32 %inc = arith.addi %val, %const1 : i32
%size = arith.index_cast %inc : i32 to index %size = arith.index_cast %inc : i32 to index
%alloc1 = memref.alloc(%size) : memref<?xf32> %alloc1 = memref.alloc(%size) : memref<?xf32>
br ^loopHeader(%inc, %alloc1 : i32, memref<?xf32>) cf.br ^loopHeader(%inc, %alloc1 : i32, memref<?xf32>)
^exit(%buff3 : memref<?xf32>): ^exit(%buff3 : memref<?xf32>):
test.copy(%buff3, %arg3) : (memref<?xf32>, memref<?xf32>) test.copy(%buff3, %arg3) : (memref<?xf32>, memref<?xf32>)
@ -1136,17 +1136,17 @@ func @do_loop_alloc(
%arg2: memref<2xf32>, %arg2: memref<2xf32>,
%arg3: memref<2xf32>) { %arg3: memref<2xf32>) {
%const0 = arith.constant 0 : i32 %const0 = arith.constant 0 : i32
br ^loopBody(%const0, %arg2 : i32, memref<2xf32>) cf.br ^loopBody(%const0, %arg2 : i32, memref<2xf32>)
^loopBody(%val : i32, %buff2: memref<2xf32>): ^loopBody(%val : i32, %buff2: memref<2xf32>):
%const1 = arith.constant 1 : i32 %const1 = arith.constant 1 : i32
%inc = arith.addi %val, %const1 : i32 %inc = arith.addi %val, %const1 : i32
%alloc1 = memref.alloc() : memref<2xf32> %alloc1 = memref.alloc() : memref<2xf32>
br ^loopHeader(%inc, %alloc1 : i32, memref<2xf32>) cf.br ^loopHeader(%inc, %alloc1 : i32, memref<2xf32>)
^loopHeader(%i : i32, %buff : memref<2xf32>): ^loopHeader(%i : i32, %buff : memref<2xf32>):
%lessThan = arith.cmpi slt, %i, %arg1 : i32 %lessThan = arith.cmpi slt, %i, %arg1 : i32
cond_br %lessThan, cf.cond_br %lessThan,
^loopBody(%i, %buff : i32, memref<2xf32>), ^loopBody(%i, %buff : i32, memref<2xf32>),
^exit(%buff : memref<2xf32>) ^exit(%buff : memref<2xf32>)

Some files were not shown because too many files have changed in this diff Show More