forked from OSchip/llvm-project
[mlir] Split out a new ControlFlow dialect from Standard
This dialect is intended to model lower-level, branch-based control-flow constructs. The initial set of operations is: AssertOp, BranchOp, CondBranchOp, and SwitchOp, all split out from the current standard dialect. See https://discourse.llvm.org/t/standard-dialect-the-final-chapter/6061 Differential Revision: https://reviews.llvm.org/D118966
This commit is contained in:
parent
edca177cbe
commit
ace01605e0
|
@ -27,8 +27,8 @@ namespace fir::support {
|
|||
#define FLANG_NONCODEGEN_DIALECT_LIST \
|
||||
mlir::AffineDialect, FIROpsDialect, mlir::acc::OpenACCDialect, \
|
||||
mlir::omp::OpenMPDialect, mlir::scf::SCFDialect, \
|
||||
mlir::arith::ArithmeticDialect, mlir::StandardOpsDialect, \
|
||||
mlir::vector::VectorDialect
|
||||
mlir::arith::ArithmeticDialect, mlir::cf::ControlFlowDialect, \
|
||||
mlir::StandardOpsDialect, mlir::vector::VectorDialect
|
||||
|
||||
// The definitive list of dialects used by flang.
|
||||
#define FLANG_DIALECT_LIST \
|
||||
|
|
|
@ -9,7 +9,7 @@
|
|||
/// This file defines some shared command-line options that can be used when
|
||||
/// debugging the test tools. This file must be included into the tool.
|
||||
|
||||
#include "mlir/Conversion/SCFToStandard/SCFToStandard.h"
|
||||
#include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h"
|
||||
#include "mlir/Pass/PassManager.h"
|
||||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
||||
#include "mlir/Transforms/Passes.h"
|
||||
|
@ -139,7 +139,7 @@ inline void createDefaultFIROptimizerPassPipeline(mlir::PassManager &pm) {
|
|||
|
||||
// convert control flow to CFG form
|
||||
fir::addCfgConversionPass(pm);
|
||||
pm.addPass(mlir::createLowerToCFGPass());
|
||||
pm.addPass(mlir::createConvertSCFToCFPass());
|
||||
|
||||
pm.addPass(mlir::createCanonicalizerPass(config));
|
||||
}
|
||||
|
|
|
@ -32,7 +32,7 @@ add_flang_library(FortranLower
|
|||
FortranSemantics
|
||||
MLIRAffineToStandard
|
||||
MLIRLLVMIR
|
||||
MLIRSCFToStandard
|
||||
MLIRSCFToControlFlow
|
||||
MLIRStandard
|
||||
|
||||
LINK_COMPONENTS
|
||||
|
|
|
@ -18,6 +18,7 @@
|
|||
#include "flang/Optimizer/Dialect/FIROps.h"
|
||||
#include "flang/Optimizer/Support/TypeCode.h"
|
||||
#include "mlir/Conversion/ArithmeticToLLVM/ArithmeticToLLVM.h"
|
||||
#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
|
||||
#include "mlir/Conversion/LLVMCommon/Pattern.h"
|
||||
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
|
@ -3293,6 +3294,8 @@ public:
|
|||
mlir::populateStdToLLVMConversionPatterns(typeConverter, pattern);
|
||||
mlir::arith::populateArithmeticToLLVMConversionPatterns(typeConverter,
|
||||
pattern);
|
||||
mlir::cf::populateControlFlowToLLVMConversionPatterns(typeConverter,
|
||||
pattern);
|
||||
mlir::ConversionTarget target{*context};
|
||||
target.addLegalDialect<mlir::LLVM::LLVMDialect>();
|
||||
|
||||
|
|
|
@ -13,6 +13,7 @@
|
|||
#include "flang/Optimizer/Dialect/FIRDialect.h"
|
||||
#include "flang/Optimizer/Support/FIRContext.h"
|
||||
#include "flang/Optimizer/Transforms/Passes.h"
|
||||
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
|
||||
#include "mlir/Dialect/SCF/SCF.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
#include "llvm/Support/Debug.h"
|
||||
|
@ -332,9 +333,9 @@ ArrayCopyAnalysis::arrayAccesses(ArrayLoadOp load) {
|
|||
<< "add modify {" << *owner << "} to array value set\n");
|
||||
accesses.push_back(owner);
|
||||
appendToQueue(update.getResult(1));
|
||||
} else if (auto br = mlir::dyn_cast<mlir::BranchOp>(owner)) {
|
||||
} else if (auto br = mlir::dyn_cast<mlir::cf::BranchOp>(owner)) {
|
||||
branchOp(br.getDest(), br.getDestOperands());
|
||||
} else if (auto br = mlir::dyn_cast<mlir::CondBranchOp>(owner)) {
|
||||
} else if (auto br = mlir::dyn_cast<mlir::cf::CondBranchOp>(owner)) {
|
||||
branchOp(br.getTrueDest(), br.getTrueOperands());
|
||||
branchOp(br.getFalseDest(), br.getFalseOperands());
|
||||
} else if (mlir::isa<ArrayMergeStoreOp>(owner)) {
|
||||
|
@ -789,9 +790,9 @@ public:
|
|||
patterns1.insert<ArrayUpdateConversion>(context, analysis, useMap);
|
||||
patterns1.insert<ArrayModifyConversion>(context, analysis, useMap);
|
||||
mlir::ConversionTarget target(*context);
|
||||
target.addLegalDialect<FIROpsDialect, mlir::scf::SCFDialect,
|
||||
mlir::arith::ArithmeticDialect,
|
||||
mlir::StandardOpsDialect>();
|
||||
target.addLegalDialect<
|
||||
FIROpsDialect, mlir::scf::SCFDialect, mlir::arith::ArithmeticDialect,
|
||||
mlir::cf::ControlFlowDialect, mlir::StandardOpsDialect>();
|
||||
target.addIllegalOp<ArrayFetchOp, ArrayUpdateOp, ArrayModifyOp>();
|
||||
// Rewrite the array fetch and array update ops.
|
||||
if (mlir::failed(
|
||||
|
|
|
@ -11,6 +11,7 @@
|
|||
#include "flang/Optimizer/Dialect/FIROps.h"
|
||||
#include "flang/Optimizer/Transforms/Passes.h"
|
||||
#include "mlir/Dialect/Affine/IR/AffineOps.h"
|
||||
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
|
@ -84,7 +85,7 @@ public:
|
|||
loopOperands.append(operands.begin(), operands.end());
|
||||
loopOperands.push_back(iters);
|
||||
|
||||
rewriter.create<mlir::BranchOp>(loc, conditionalBlock, loopOperands);
|
||||
rewriter.create<mlir::cf::BranchOp>(loc, conditionalBlock, loopOperands);
|
||||
|
||||
// Last loop block
|
||||
auto *terminator = lastBlock->getTerminator();
|
||||
|
@ -105,7 +106,7 @@ public:
|
|||
: terminator->operand_begin();
|
||||
loopCarried.append(begin, terminator->operand_end());
|
||||
loopCarried.push_back(itersMinusOne);
|
||||
rewriter.create<mlir::BranchOp>(loc, conditionalBlock, loopCarried);
|
||||
rewriter.create<mlir::cf::BranchOp>(loc, conditionalBlock, loopCarried);
|
||||
rewriter.eraseOp(terminator);
|
||||
|
||||
// Conditional block
|
||||
|
@ -114,9 +115,9 @@ public:
|
|||
auto comparison = rewriter.create<mlir::arith::CmpIOp>(
|
||||
loc, arith::CmpIPredicate::sgt, itersLeft, zero);
|
||||
|
||||
rewriter.create<mlir::CondBranchOp>(loc, comparison, firstBlock,
|
||||
llvm::ArrayRef<mlir::Value>(), endBlock,
|
||||
llvm::ArrayRef<mlir::Value>());
|
||||
rewriter.create<mlir::cf::CondBranchOp>(
|
||||
loc, comparison, firstBlock, llvm::ArrayRef<mlir::Value>(), endBlock,
|
||||
llvm::ArrayRef<mlir::Value>());
|
||||
|
||||
// The result of the loop operation is the values of the condition block
|
||||
// arguments except the induction variable on the last iteration.
|
||||
|
@ -155,7 +156,7 @@ public:
|
|||
} else {
|
||||
continueBlock =
|
||||
rewriter.createBlock(remainingOpsBlock, ifOp.getResultTypes());
|
||||
rewriter.create<mlir::BranchOp>(loc, remainingOpsBlock);
|
||||
rewriter.create<mlir::cf::BranchOp>(loc, remainingOpsBlock);
|
||||
}
|
||||
|
||||
// Move blocks from the "then" region to the region containing 'fir.if',
|
||||
|
@ -165,7 +166,8 @@ public:
|
|||
auto *ifOpTerminator = ifOpRegion.back().getTerminator();
|
||||
auto ifOpTerminatorOperands = ifOpTerminator->getOperands();
|
||||
rewriter.setInsertionPointToEnd(&ifOpRegion.back());
|
||||
rewriter.create<mlir::BranchOp>(loc, continueBlock, ifOpTerminatorOperands);
|
||||
rewriter.create<mlir::cf::BranchOp>(loc, continueBlock,
|
||||
ifOpTerminatorOperands);
|
||||
rewriter.eraseOp(ifOpTerminator);
|
||||
rewriter.inlineRegionBefore(ifOpRegion, continueBlock);
|
||||
|
||||
|
@ -179,14 +181,14 @@ public:
|
|||
auto *otherwiseTerm = otherwiseRegion.back().getTerminator();
|
||||
auto otherwiseTermOperands = otherwiseTerm->getOperands();
|
||||
rewriter.setInsertionPointToEnd(&otherwiseRegion.back());
|
||||
rewriter.create<mlir::BranchOp>(loc, continueBlock,
|
||||
otherwiseTermOperands);
|
||||
rewriter.create<mlir::cf::BranchOp>(loc, continueBlock,
|
||||
otherwiseTermOperands);
|
||||
rewriter.eraseOp(otherwiseTerm);
|
||||
rewriter.inlineRegionBefore(otherwiseRegion, continueBlock);
|
||||
}
|
||||
|
||||
rewriter.setInsertionPointToEnd(condBlock);
|
||||
rewriter.create<mlir::CondBranchOp>(
|
||||
rewriter.create<mlir::cf::CondBranchOp>(
|
||||
loc, ifOp.condition(), ifOpBlock, llvm::ArrayRef<mlir::Value>(),
|
||||
otherwiseBlock, llvm::ArrayRef<mlir::Value>());
|
||||
rewriter.replaceOp(ifOp, continueBlock->getArguments());
|
||||
|
@ -241,7 +243,7 @@ public:
|
|||
auto begin = whileOp.finalValue() ? std::next(terminator->operand_begin())
|
||||
: terminator->operand_begin();
|
||||
loopCarried.append(begin, terminator->operand_end());
|
||||
rewriter.create<mlir::BranchOp>(loc, conditionBlock, loopCarried);
|
||||
rewriter.create<mlir::cf::BranchOp>(loc, conditionBlock, loopCarried);
|
||||
rewriter.eraseOp(terminator);
|
||||
|
||||
// Compute loop bounds before branching to the condition.
|
||||
|
@ -256,7 +258,7 @@ public:
|
|||
destOperands.push_back(lowerBound);
|
||||
auto iterOperands = whileOp.getIterOperands();
|
||||
destOperands.append(iterOperands.begin(), iterOperands.end());
|
||||
rewriter.create<mlir::BranchOp>(loc, conditionBlock, destOperands);
|
||||
rewriter.create<mlir::cf::BranchOp>(loc, conditionBlock, destOperands);
|
||||
|
||||
// With the body block done, we can fill in the condition block.
|
||||
rewriter.setInsertionPointToEnd(conditionBlock);
|
||||
|
@ -278,9 +280,9 @@ public:
|
|||
// Remember to AND in the early-exit bool.
|
||||
auto comparison =
|
||||
rewriter.create<mlir::arith::AndIOp>(loc, iterateVar, cmp2);
|
||||
rewriter.create<mlir::CondBranchOp>(loc, comparison, firstBodyBlock,
|
||||
llvm::ArrayRef<mlir::Value>(), endBlock,
|
||||
llvm::ArrayRef<mlir::Value>());
|
||||
rewriter.create<mlir::cf::CondBranchOp>(
|
||||
loc, comparison, firstBodyBlock, llvm::ArrayRef<mlir::Value>(),
|
||||
endBlock, llvm::ArrayRef<mlir::Value>());
|
||||
// The result of the loop operation is the values of the condition block
|
||||
// arguments except the induction variable on the last iteration.
|
||||
auto args = whileOp.finalValue()
|
||||
|
@ -300,8 +302,8 @@ public:
|
|||
patterns.insert<CfgLoopConv, CfgIfConv, CfgIterWhileConv>(
|
||||
context, forceLoopToExecuteOnce);
|
||||
mlir::ConversionTarget target(*context);
|
||||
target.addLegalDialect<mlir::AffineDialect, FIROpsDialect,
|
||||
mlir::StandardOpsDialect>();
|
||||
target.addLegalDialect<mlir::AffineDialect, mlir::cf::ControlFlowDialect,
|
||||
FIROpsDialect, mlir::StandardOpsDialect>();
|
||||
|
||||
// apply the patterns
|
||||
target.addIllegalOp<ResultOp, DoLoopOp, IfOp, IterWhileOp>();
|
||||
|
|
|
@ -10,10 +10,10 @@ func @select_case_charachter(%arg0: !fir.char<2, 10>, %arg1: !fir.char<2, 10>, %
|
|||
unit, ^bb3]
|
||||
^bb1:
|
||||
%c1_i32 = arith.constant 1 : i32
|
||||
br ^bb3
|
||||
cf.br ^bb3
|
||||
^bb2:
|
||||
%c2_i32 = arith.constant 2 : i32
|
||||
br ^bb3
|
||||
cf.br ^bb3
|
||||
^bb3:
|
||||
return
|
||||
}
|
||||
|
|
|
@ -1175,23 +1175,23 @@ func @select_case_integer(%arg0: !fir.ref<i32>) -> i32 {
|
|||
^bb1: // pred: ^bb0
|
||||
%c1_i32_0 = arith.constant 1 : i32
|
||||
fir.store %c1_i32_0 to %arg0 : !fir.ref<i32>
|
||||
br ^bb6
|
||||
cf.br ^bb6
|
||||
^bb2: // pred: ^bb0
|
||||
%c2_i32_1 = arith.constant 2 : i32
|
||||
fir.store %c2_i32_1 to %arg0 : !fir.ref<i32>
|
||||
br ^bb6
|
||||
cf.br ^bb6
|
||||
^bb3: // pred: ^bb0
|
||||
%c0_i32 = arith.constant 0 : i32
|
||||
fir.store %c0_i32 to %arg0 : !fir.ref<i32>
|
||||
br ^bb6
|
||||
cf.br ^bb6
|
||||
^bb4: // pred: ^bb0
|
||||
%c4_i32_2 = arith.constant 4 : i32
|
||||
fir.store %c4_i32_2 to %arg0 : !fir.ref<i32>
|
||||
br ^bb6
|
||||
cf.br ^bb6
|
||||
^bb5: // 3 preds: ^bb0, ^bb0, ^bb0
|
||||
%c7_i32_3 = arith.constant 7 : i32
|
||||
fir.store %c7_i32_3 to %arg0 : !fir.ref<i32>
|
||||
br ^bb6
|
||||
cf.br ^bb6
|
||||
^bb6: // 5 preds: ^bb1, ^bb2, ^bb3, ^bb4, ^bb5
|
||||
%3 = fir.load %arg0 : !fir.ref<i32>
|
||||
return %3 : i32
|
||||
|
@ -1275,10 +1275,10 @@ func @select_case_logical(%arg0: !fir.ref<!fir.logical<4>>) {
|
|||
unit, ^bb3]
|
||||
^bb1:
|
||||
%c1_i32 = arith.constant 1 : i32
|
||||
br ^bb3
|
||||
cf.br ^bb3
|
||||
^bb2:
|
||||
%c2_i32 = arith.constant 2 : i32
|
||||
br ^bb3
|
||||
cf.br ^bb3
|
||||
^bb3:
|
||||
return
|
||||
}
|
||||
|
|
|
@ -9,10 +9,10 @@ func @load_store_chain_removal(%arg0: !fir.ref<!fir.array<60xi32>>, %arg1: !fir.
|
|||
%c1 = arith.constant 1 : index
|
||||
%0 = fir.alloca i32 {bindc_name = "i", uniq_name = "_QFf1dcEi"}
|
||||
%1 = fir.alloca !fir.array<60xi32> {bindc_name = "t1", uniq_name = "_QFf1dcEt1"}
|
||||
br ^bb1(%c1, %c60 : index, index)
|
||||
cf.br ^bb1(%c1, %c60 : index, index)
|
||||
^bb1(%2: index, %3: index): // 2 preds: ^bb0, ^bb2
|
||||
%4 = arith.cmpi sgt, %3, %c0 : index
|
||||
cond_br %4, ^bb2, ^bb3
|
||||
cf.cond_br %4, ^bb2, ^bb3
|
||||
^bb2: // pred: ^bb1
|
||||
%5 = fir.convert %2 : (index) -> i32
|
||||
fir.store %5 to %0 : !fir.ref<i32>
|
||||
|
@ -26,14 +26,14 @@ func @load_store_chain_removal(%arg0: !fir.ref<!fir.array<60xi32>>, %arg1: !fir.
|
|||
fir.store %11 to %12 : !fir.ref<i32>
|
||||
%13 = arith.addi %2, %c1 : index
|
||||
%14 = arith.subi %3, %c1 : index
|
||||
br ^bb1(%13, %14 : index, index)
|
||||
cf.br ^bb1(%13, %14 : index, index)
|
||||
^bb3: // pred: ^bb1
|
||||
%15 = fir.convert %2 : (index) -> i32
|
||||
fir.store %15 to %0 : !fir.ref<i32>
|
||||
br ^bb4(%c1, %c60 : index, index)
|
||||
cf.br ^bb4(%c1, %c60 : index, index)
|
||||
^bb4(%16: index, %17: index): // 2 preds: ^bb3, ^bb5
|
||||
%18 = arith.cmpi sgt, %17, %c0 : index
|
||||
cond_br %18, ^bb5, ^bb6
|
||||
cf.cond_br %18, ^bb5, ^bb6
|
||||
^bb5: // pred: ^bb4
|
||||
%19 = fir.convert %16 : (index) -> i32
|
||||
fir.store %19 to %0 : !fir.ref<i32>
|
||||
|
@ -49,7 +49,7 @@ func @load_store_chain_removal(%arg0: !fir.ref<!fir.array<60xi32>>, %arg1: !fir.
|
|||
fir.store %27 to %28 : !fir.ref<i32>
|
||||
%29 = arith.addi %16, %c1 : index
|
||||
%30 = arith.subi %17, %c1 : index
|
||||
br ^bb4(%29, %30 : index, index)
|
||||
cf.br ^bb4(%29, %30 : index, index)
|
||||
^bb6: // pred: ^bb4
|
||||
%31 = fir.convert %16 : (index) -> i32
|
||||
fir.store %31 to %0 : !fir.ref<i32>
|
||||
|
|
|
@ -13,7 +13,7 @@ FIRTransforms
|
|||
FIRBuilder
|
||||
${dialect_libs}
|
||||
MLIRAffineToStandard
|
||||
MLIRSCFToStandard
|
||||
MLIRSCFToControlFlow
|
||||
FortranCommon
|
||||
FortranParser
|
||||
FortranEvaluate
|
||||
|
|
|
@ -38,7 +38,6 @@
|
|||
#include "flang/Semantics/semantics.h"
|
||||
#include "flang/Semantics/unparse-with-symbols.h"
|
||||
#include "flang/Version.inc"
|
||||
#include "mlir/Conversion/SCFToStandard/SCFToStandard.h"
|
||||
#include "mlir/IR/AsmState.h"
|
||||
#include "mlir/IR/BuiltinOps.h"
|
||||
#include "mlir/IR/MLIRContext.h"
|
||||
|
|
|
@ -18,7 +18,7 @@ target_link_libraries(fir-opt PRIVATE
|
|||
MLIRTransforms
|
||||
MLIRAffineToStandard
|
||||
MLIRAnalysis
|
||||
MLIRSCFToStandard
|
||||
MLIRSCFToControlFlow
|
||||
MLIRParser
|
||||
MLIRStandardToLLVM
|
||||
MLIRSupport
|
||||
|
|
|
@ -17,7 +17,7 @@ target_link_libraries(tco PRIVATE
|
|||
MLIRTransforms
|
||||
MLIRAffineToStandard
|
||||
MLIRAnalysis
|
||||
MLIRSCFToStandard
|
||||
MLIRSCFToControlFlow
|
||||
MLIRParser
|
||||
MLIRStandardToLLVM
|
||||
MLIRSupport
|
||||
|
|
|
@ -17,7 +17,6 @@
|
|||
#include "flang/Optimizer/Support/InternalNames.h"
|
||||
#include "flang/Optimizer/Support/KindMapping.h"
|
||||
#include "flang/Optimizer/Transforms/Passes.h"
|
||||
#include "mlir/Conversion/SCFToStandard/SCFToStandard.h"
|
||||
#include "mlir/IR/AsmState.h"
|
||||
#include "mlir/IR/BuiltinOps.h"
|
||||
#include "mlir/IR/MLIRContext.h"
|
||||
|
|
|
@ -26,7 +26,7 @@ def setup_passes(mlir_module):
|
|||
f"sparse-tensor-conversion,"
|
||||
f"builtin.func"
|
||||
f"(linalg-bufferize,convert-linalg-to-loops,convert-vector-to-scf),"
|
||||
f"convert-scf-to-std,"
|
||||
f"convert-scf-to-cf,"
|
||||
f"func-bufferize,"
|
||||
f"arith-bufferize,"
|
||||
f"builtin.func(tensor-bufferize,finalizing-bufferize),"
|
||||
|
|
|
@ -41,12 +41,12 @@ Example for breaking the invariant:
|
|||
```mlir
|
||||
func @condBranch(%arg0: i1, %arg1: memref<2xf32>) {
|
||||
%0 = memref.alloc() : memref<2xf32>
|
||||
cond_br %arg0, ^bb1, ^bb2
|
||||
cf.cond_br %arg0, ^bb1, ^bb2
|
||||
^bb1:
|
||||
br ^bb3()
|
||||
cf.br ^bb3()
|
||||
^bb2:
|
||||
partial_write(%0, %0)
|
||||
br ^bb3()
|
||||
cf.br ^bb3()
|
||||
^bb3():
|
||||
test.copy(%0, %arg1) : (memref<2xf32>, memref<2xf32>) -> ()
|
||||
return
|
||||
|
@ -74,13 +74,13 @@ untracked allocations are mixed:
|
|||
func @mixedAllocation(%arg0: i1) {
|
||||
%0 = memref.alloca() : memref<2xf32> // aliases: %2
|
||||
%1 = memref.alloc() : memref<2xf32> // aliases: %2
|
||||
cond_br %arg0, ^bb1, ^bb2
|
||||
cf.cond_br %arg0, ^bb1, ^bb2
|
||||
^bb1:
|
||||
use(%0)
|
||||
br ^bb3(%0 : memref<2xf32>)
|
||||
cf.br ^bb3(%0 : memref<2xf32>)
|
||||
^bb2:
|
||||
use(%1)
|
||||
br ^bb3(%1 : memref<2xf32>)
|
||||
cf.br ^bb3(%1 : memref<2xf32>)
|
||||
^bb3(%2: memref<2xf32>):
|
||||
...
|
||||
}
|
||||
|
@ -129,13 +129,13 @@ BufferHoisting pass:
|
|||
|
||||
```mlir
|
||||
func @condBranch(%arg0: i1, %arg1: memref<2xf32>, %arg2: memref<2xf32>) {
|
||||
cond_br %arg0, ^bb1, ^bb2
|
||||
cf.cond_br %arg0, ^bb1, ^bb2
|
||||
^bb1:
|
||||
br ^bb3(%arg1 : memref<2xf32>)
|
||||
cf.br ^bb3(%arg1 : memref<2xf32>)
|
||||
^bb2:
|
||||
%0 = memref.alloc() : memref<2xf32> // aliases: %1
|
||||
use(%0)
|
||||
br ^bb3(%0 : memref<2xf32>)
|
||||
cf.br ^bb3(%0 : memref<2xf32>)
|
||||
^bb3(%1: memref<2xf32>): // %1 could be %0 or %arg1
|
||||
test.copy(%1, %arg2) : (memref<2xf32>, memref<2xf32>) -> ()
|
||||
return
|
||||
|
@ -150,12 +150,12 @@ of code:
|
|||
```mlir
|
||||
func @condBranch(%arg0: i1, %arg1: memref<2xf32>, %arg2: memref<2xf32>) {
|
||||
%0 = memref.alloc() : memref<2xf32> // moved to bb0
|
||||
cond_br %arg0, ^bb1, ^bb2
|
||||
cf.cond_br %arg0, ^bb1, ^bb2
|
||||
^bb1:
|
||||
br ^bb3(%arg1 : memref<2xf32>)
|
||||
cf.br ^bb3(%arg1 : memref<2xf32>)
|
||||
^bb2:
|
||||
use(%0)
|
||||
br ^bb3(%0 : memref<2xf32>)
|
||||
cf.br ^bb3(%0 : memref<2xf32>)
|
||||
^bb3(%1: memref<2xf32>):
|
||||
test.copy(%1, %arg2) : (memref<2xf32>, memref<2xf32>) -> ()
|
||||
return
|
||||
|
@ -175,14 +175,14 @@ func @condBranchDynamicType(
|
|||
%arg1: memref<?xf32>,
|
||||
%arg2: memref<?xf32>,
|
||||
%arg3: index) {
|
||||
cond_br %arg0, ^bb1, ^bb2(%arg3: index)
|
||||
cf.cond_br %arg0, ^bb1, ^bb2(%arg3: index)
|
||||
^bb1:
|
||||
br ^bb3(%arg1 : memref<?xf32>)
|
||||
cf.br ^bb3(%arg1 : memref<?xf32>)
|
||||
^bb2(%0: index):
|
||||
%1 = memref.alloc(%0) : memref<?xf32> // cannot be moved upwards due to the data
|
||||
// dependency to %0
|
||||
use(%1)
|
||||
br ^bb3(%1 : memref<?xf32>)
|
||||
cf.br ^bb3(%1 : memref<?xf32>)
|
||||
^bb3(%2: memref<?xf32>):
|
||||
test.copy(%2, %arg2) : (memref<?xf32>, memref<?xf32>) -> ()
|
||||
return
|
||||
|
@ -201,14 +201,14 @@ allocations have already been placed:
|
|||
```mlir
|
||||
func @branch(%arg0: i1) {
|
||||
%0 = memref.alloc() : memref<2xf32> // aliases: %2
|
||||
cond_br %arg0, ^bb1, ^bb2
|
||||
cf.cond_br %arg0, ^bb1, ^bb2
|
||||
^bb1:
|
||||
%1 = memref.alloc() : memref<2xf32> // resides here for demonstration purposes
|
||||
// aliases: %2
|
||||
br ^bb3(%1 : memref<2xf32>)
|
||||
cf.br ^bb3(%1 : memref<2xf32>)
|
||||
^bb2:
|
||||
use(%0)
|
||||
br ^bb3(%0 : memref<2xf32>)
|
||||
cf.br ^bb3(%0 : memref<2xf32>)
|
||||
^bb3(%2: memref<2xf32>):
|
||||
…
|
||||
return
|
||||
|
@ -233,16 +233,16 @@ result:
|
|||
```mlir
|
||||
func @branch(%arg0: i1) {
|
||||
%0 = memref.alloc() : memref<2xf32>
|
||||
cond_br %arg0, ^bb1, ^bb2
|
||||
cf.cond_br %arg0, ^bb1, ^bb2
|
||||
^bb1:
|
||||
%1 = memref.alloc() : memref<2xf32>
|
||||
%3 = bufferization.clone %1 : (memref<2xf32>) -> (memref<2xf32>)
|
||||
memref.dealloc %1 : memref<2xf32> // %1 can be safely freed here
|
||||
br ^bb3(%3 : memref<2xf32>)
|
||||
cf.br ^bb3(%3 : memref<2xf32>)
|
||||
^bb2:
|
||||
use(%0)
|
||||
%4 = bufferization.clone %0 : (memref<2xf32>) -> (memref<2xf32>)
|
||||
br ^bb3(%4 : memref<2xf32>)
|
||||
cf.br ^bb3(%4 : memref<2xf32>)
|
||||
^bb3(%2: memref<2xf32>):
|
||||
…
|
||||
memref.dealloc %2 : memref<2xf32> // free temp buffer %2
|
||||
|
@ -273,23 +273,23 @@ func @condBranchDynamicTypeNested(
|
|||
%arg1: memref<?xf32>, // aliases: %3, %4
|
||||
%arg2: memref<?xf32>,
|
||||
%arg3: index) {
|
||||
cond_br %arg0, ^bb1, ^bb2(%arg3: index)
|
||||
cf.cond_br %arg0, ^bb1, ^bb2(%arg3: index)
|
||||
^bb1:
|
||||
br ^bb6(%arg1 : memref<?xf32>)
|
||||
cf.br ^bb6(%arg1 : memref<?xf32>)
|
||||
^bb2(%0: index):
|
||||
%1 = memref.alloc(%0) : memref<?xf32> // cannot be moved upwards due to the data
|
||||
// dependency to %0
|
||||
// aliases: %2, %3, %4
|
||||
use(%1)
|
||||
cond_br %arg0, ^bb3, ^bb4
|
||||
cf.cond_br %arg0, ^bb3, ^bb4
|
||||
^bb3:
|
||||
br ^bb5(%1 : memref<?xf32>)
|
||||
cf.br ^bb5(%1 : memref<?xf32>)
|
||||
^bb4:
|
||||
br ^bb5(%1 : memref<?xf32>)
|
||||
cf.br ^bb5(%1 : memref<?xf32>)
|
||||
^bb5(%2: memref<?xf32>): // non-crit. alias of %1, since %1 dominates %2
|
||||
br ^bb6(%2 : memref<?xf32>)
|
||||
cf.br ^bb6(%2 : memref<?xf32>)
|
||||
^bb6(%3: memref<?xf32>): // crit. alias of %arg1 and %2 (in other words %1)
|
||||
br ^bb7(%3 : memref<?xf32>)
|
||||
cf.br ^bb7(%3 : memref<?xf32>)
|
||||
^bb7(%4: memref<?xf32>): // non-crit. alias of %3, since %3 dominates %4
|
||||
test.copy(%4, %arg2) : (memref<?xf32>, memref<?xf32>) -> ()
|
||||
return
|
||||
|
@ -306,25 +306,25 @@ func @condBranchDynamicTypeNested(
|
|||
%arg1: memref<?xf32>,
|
||||
%arg2: memref<?xf32>,
|
||||
%arg3: index) {
|
||||
cond_br %arg0, ^bb1, ^bb2(%arg3 : index)
|
||||
cf.cond_br %arg0, ^bb1, ^bb2(%arg3 : index)
|
||||
^bb1:
|
||||
// temp buffer required due to alias %3
|
||||
%5 = bufferization.clone %arg1 : (memref<?xf32>) -> (memref<?xf32>)
|
||||
br ^bb6(%5 : memref<?xf32>)
|
||||
cf.br ^bb6(%5 : memref<?xf32>)
|
||||
^bb2(%0: index):
|
||||
%1 = memref.alloc(%0) : memref<?xf32>
|
||||
use(%1)
|
||||
cond_br %arg0, ^bb3, ^bb4
|
||||
cf.cond_br %arg0, ^bb3, ^bb4
|
||||
^bb3:
|
||||
br ^bb5(%1 : memref<?xf32>)
|
||||
cf.br ^bb5(%1 : memref<?xf32>)
|
||||
^bb4:
|
||||
br ^bb5(%1 : memref<?xf32>)
|
||||
cf.br ^bb5(%1 : memref<?xf32>)
|
||||
^bb5(%2: memref<?xf32>):
|
||||
%6 = bufferization.clone %1 : (memref<?xf32>) -> (memref<?xf32>)
|
||||
memref.dealloc %1 : memref<?xf32>
|
||||
br ^bb6(%6 : memref<?xf32>)
|
||||
cf.br ^bb6(%6 : memref<?xf32>)
|
||||
^bb6(%3: memref<?xf32>):
|
||||
br ^bb7(%3 : memref<?xf32>)
|
||||
cf.br ^bb7(%3 : memref<?xf32>)
|
||||
^bb7(%4: memref<?xf32>):
|
||||
test.copy(%4, %arg2) : (memref<?xf32>, memref<?xf32>) -> ()
|
||||
memref.dealloc %3 : memref<?xf32> // free %3, since %4 is a non-crit. alias of %3
|
||||
|
|
|
@ -295,7 +295,7 @@ A few examples are shown below:
|
|||
```mlir
|
||||
// Expect an error on the same line.
|
||||
func @bad_branch() {
|
||||
br ^missing // expected-error {{reference to an undefined block}}
|
||||
cf.br ^missing // expected-error {{reference to an undefined block}}
|
||||
}
|
||||
|
||||
// Expect an error on an adjacent line.
|
||||
|
|
|
@ -114,8 +114,8 @@ struct MyTarget : public ConversionTarget {
|
|||
/// All operations within the GPU dialect are illegal.
|
||||
addIllegalDialect<GPUDialect>();
|
||||
|
||||
/// Mark `std.br` and `std.cond_br` as illegal.
|
||||
addIllegalOp<BranchOp, CondBranchOp>();
|
||||
/// Mark `cf.br` and `cf.cond_br` as illegal.
|
||||
addIllegalOp<cf::BranchOp, cf::CondBranchOp>();
|
||||
}
|
||||
|
||||
/// Implement the default legalization handler to handle operations marked as
|
||||
|
|
|
@ -23,10 +23,11 @@ argument `-declare-variables-at-top`.
|
|||
Besides operations that are part of the EmitC dialect, the Cpp target supports
|
||||
translating the following operations:
|
||||
|
||||
* 'cf' Dialect
|
||||
* `cf.br`
|
||||
* `cf.cond_br`
|
||||
* 'std' Dialect
|
||||
* `std.br`
|
||||
* `std.call`
|
||||
* `std.cond_br`
|
||||
* `std.constant`
|
||||
* `std.return`
|
||||
* 'scf' Dialect
|
||||
|
|
|
@ -391,21 +391,21 @@ arguments:
|
|||
```mlir
|
||||
func @simple(i64, i1) -> i64 {
|
||||
^bb0(%a: i64, %cond: i1): // Code dominated by ^bb0 may refer to %a
|
||||
cond_br %cond, ^bb1, ^bb2
|
||||
cf.cond_br %cond, ^bb1, ^bb2
|
||||
|
||||
^bb1:
|
||||
br ^bb3(%a: i64) // Branch passes %a as the argument
|
||||
cf.br ^bb3(%a: i64) // Branch passes %a as the argument
|
||||
|
||||
^bb2:
|
||||
%b = arith.addi %a, %a : i64
|
||||
br ^bb3(%b: i64) // Branch passes %b as the argument
|
||||
cf.br ^bb3(%b: i64) // Branch passes %b as the argument
|
||||
|
||||
// ^bb3 receives an argument, named %c, from predecessors
|
||||
// and passes it on to bb4 along with %a. %a is referenced
|
||||
// directly from its defining operation and is not passed through
|
||||
// an argument of ^bb3.
|
||||
^bb3(%c: i64):
|
||||
br ^bb4(%c, %a : i64, i64)
|
||||
cf.br ^bb4(%c, %a : i64, i64)
|
||||
|
||||
^bb4(%d : i64, %e : i64):
|
||||
%0 = arith.addi %d, %e : i64
|
||||
|
@ -525,12 +525,12 @@ Example:
|
|||
```mlir
|
||||
func @accelerator_compute(i64, i1) -> i64 { // An SSACFG region
|
||||
^bb0(%a: i64, %cond: i1): // Code dominated by ^bb0 may refer to %a
|
||||
cond_br %cond, ^bb1, ^bb2
|
||||
cf.cond_br %cond, ^bb1, ^bb2
|
||||
|
||||
^bb1:
|
||||
// This def for %value does not dominate ^bb2
|
||||
%value = "op.convert"(%a) : (i64) -> i64
|
||||
br ^bb3(%a: i64) // Branch passes %a as the argument
|
||||
cf.br ^bb3(%a: i64) // Branch passes %a as the argument
|
||||
|
||||
^bb2:
|
||||
accelerator.launch() { // An SSACFG region
|
||||
|
|
|
@ -356,24 +356,24 @@ Example output is shown below:
|
|||
|
||||
```
|
||||
//===-------------------------------------------===//
|
||||
Processing operation : 'std.cond_br'(0x60f000001120) {
|
||||
"std.cond_br"(%arg0)[^bb2, ^bb2] {operand_segment_sizes = dense<[1, 0, 0]> : vector<3xi32>} : (i1) -> ()
|
||||
Processing operation : 'cf.cond_br'(0x60f000001120) {
|
||||
"cf.cond_br"(%arg0)[^bb2, ^bb2] {operand_segment_sizes = dense<[1, 0, 0]> : vector<3xi32>} : (i1) -> ()
|
||||
|
||||
* Pattern SimplifyConstCondBranchPred : 'std.cond_br -> ()' {
|
||||
* Pattern SimplifyConstCondBranchPred : 'cf.cond_br -> ()' {
|
||||
} -> failure : pattern failed to match
|
||||
|
||||
* Pattern SimplifyCondBranchIdenticalSuccessors : 'std.cond_br -> ()' {
|
||||
** Insert : 'std.br'(0x60b000003690)
|
||||
** Replace : 'std.cond_br'(0x60f000001120)
|
||||
* Pattern SimplifyCondBranchIdenticalSuccessors : 'cf.cond_br -> ()' {
|
||||
** Insert : 'cf.br'(0x60b000003690)
|
||||
** Replace : 'cf.cond_br'(0x60f000001120)
|
||||
} -> success : pattern applied successfully
|
||||
} -> success : pattern matched
|
||||
//===-------------------------------------------===//
|
||||
```
|
||||
|
||||
This output is describing the processing of a `std.cond_br` operation. We first
|
||||
This output is describing the processing of a `cf.cond_br` operation. We first
|
||||
try to apply the `SimplifyConstCondBranchPred`, which fails. From there, another
|
||||
pattern (`SimplifyCondBranchIdenticalSuccessors`) is applied that matches the
|
||||
`std.cond_br` and replaces it with a `std.br`.
|
||||
`cf.cond_br` and replaces it with a `cf.br`.
|
||||
|
||||
## Debugging
|
||||
|
||||
|
|
|
@ -560,24 +560,24 @@ func @search(%A: memref<?x?xi32>, %S: <?xi32>, %key : i32) {
|
|||
|
||||
func @search_body(%A: memref<?x?xi32>, %S: memref<?xi32>, %key: i32, %i : i32) {
|
||||
%nj = memref.dim %A, 1 : memref<?x?xi32>
|
||||
br ^bb1(0)
|
||||
cf.br ^bb1(0)
|
||||
|
||||
^bb1(%j: i32)
|
||||
%p1 = arith.cmpi "lt", %j, %nj : i32
|
||||
cond_br %p1, ^bb2, ^bb5
|
||||
cf.cond_br %p1, ^bb2, ^bb5
|
||||
|
||||
^bb2:
|
||||
%v = affine.load %A[%i, %j] : memref<?x?xi32>
|
||||
%p2 = arith.cmpi "eq", %v, %key : i32
|
||||
cond_br %p2, ^bb3(%j), ^bb4
|
||||
cf.cond_br %p2, ^bb3(%j), ^bb4
|
||||
|
||||
^bb3(%j: i32)
|
||||
affine.store %j, %S[%i] : memref<?xi32>
|
||||
br ^bb5
|
||||
cf.br ^bb5
|
||||
|
||||
^bb4:
|
||||
%jinc = arith.addi %j, 1 : i32
|
||||
br ^bb1(%jinc)
|
||||
cf.br ^bb1(%jinc)
|
||||
|
||||
^bb5:
|
||||
return
|
||||
|
|
|
@ -94,10 +94,11 @@ multiple stages by relying on
|
|||
```c++
|
||||
mlir::RewritePatternSet patterns(&getContext());
|
||||
mlir::populateAffineToStdConversionPatterns(patterns);
|
||||
mlir::populateLoopToStdConversionPatterns(patterns, &getContext());
|
||||
mlir::populateSCFToControlFlowConversionPatterns(patterns);
|
||||
mlir::arith::populateArithmeticToLLVMConversionPatterns(typeConverter,
|
||||
patterns);
|
||||
mlir::populateStdToLLVMConversionPatterns(typeConverter, patterns);
|
||||
mlir::cf::populateControlFlowToLLVMConversionPatterns(typeConverter, patterns);
|
||||
|
||||
// The only remaining operation, to lower from the `toy` dialect, is the
|
||||
// PrintOp.
|
||||
|
@ -207,7 +208,7 @@ define void @main() {
|
|||
%109 = memref.load double, double* %108
|
||||
%110 = call i32 (i8*, ...) @printf(i8* getelementptr inbounds ([4 x i8], [4 x i8]* @frmt_spec, i64 0, i64 0), double %109)
|
||||
%111 = add i64 %100, 1
|
||||
br label %99
|
||||
br label %99
|
||||
|
||||
...
|
||||
|
||||
|
|
|
@ -361,7 +361,7 @@
|
|||
</tspan></tspan><tspan
|
||||
x="73.476562"
|
||||
y="88.293896"><tspan
|
||||
style="font-size:5.64444px">br bb3(%0)</tspan></tspan></text>
|
||||
style="font-size:5.64444px">cf.br bb3(%0)</tspan></tspan></text>
|
||||
<text
|
||||
xml:space="preserve"
|
||||
id="text1894"
|
||||
|
|
Before Width: | Height: | Size: 15 KiB After Width: | Height: | Size: 15 KiB |
|
@ -362,7 +362,7 @@
|
|||
</tspan></tspan><tspan
|
||||
x="73.476562"
|
||||
y="88.293896"><tspan
|
||||
style="font-size:5.64444px">br bb3(%0)</tspan></tspan></text>
|
||||
style="font-size:5.64444px">cf.br bb3(%0)</tspan></tspan></text>
|
||||
<text
|
||||
xml:space="preserve"
|
||||
id="text1894"
|
||||
|
|
Before Width: | Height: | Size: 15 KiB After Width: | Height: | Size: 15 KiB |
|
@ -26,10 +26,11 @@
|
|||
|
||||
#include "mlir/Conversion/AffineToStandard/AffineToStandard.h"
|
||||
#include "mlir/Conversion/ArithmeticToLLVM/ArithmeticToLLVM.h"
|
||||
#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
|
||||
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
|
||||
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
|
||||
#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
|
||||
#include "mlir/Conversion/SCFToStandard/SCFToStandard.h"
|
||||
#include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h"
|
||||
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
|
||||
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
|
||||
#include "mlir/Dialect/Affine/IR/AffineOps.h"
|
||||
|
@ -200,10 +201,11 @@ void ToyToLLVMLoweringPass::runOnOperation() {
|
|||
// set of legal ones.
|
||||
RewritePatternSet patterns(&getContext());
|
||||
populateAffineToStdConversionPatterns(patterns);
|
||||
populateLoopToStdConversionPatterns(patterns);
|
||||
populateSCFToControlFlowConversionPatterns(patterns);
|
||||
mlir::arith::populateArithmeticToLLVMConversionPatterns(typeConverter,
|
||||
patterns);
|
||||
populateMemRefToLLVMConversionPatterns(typeConverter, patterns);
|
||||
cf::populateControlFlowToLLVMConversionPatterns(typeConverter, patterns);
|
||||
populateStdToLLVMConversionPatterns(typeConverter, patterns);
|
||||
|
||||
// The only remaining operation to lower from the `toy` dialect, is the
|
||||
|
|
|
@ -26,10 +26,11 @@
|
|||
|
||||
#include "mlir/Conversion/AffineToStandard/AffineToStandard.h"
|
||||
#include "mlir/Conversion/ArithmeticToLLVM/ArithmeticToLLVM.h"
|
||||
#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
|
||||
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
|
||||
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
|
||||
#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
|
||||
#include "mlir/Conversion/SCFToStandard/SCFToStandard.h"
|
||||
#include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h"
|
||||
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
|
||||
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
|
||||
#include "mlir/Dialect/Affine/IR/AffineOps.h"
|
||||
|
@ -200,10 +201,11 @@ void ToyToLLVMLoweringPass::runOnOperation() {
|
|||
// set of legal ones.
|
||||
RewritePatternSet patterns(&getContext());
|
||||
populateAffineToStdConversionPatterns(patterns);
|
||||
populateLoopToStdConversionPatterns(patterns);
|
||||
populateSCFToControlFlowConversionPatterns(patterns);
|
||||
mlir::arith::populateArithmeticToLLVMConversionPatterns(typeConverter,
|
||||
patterns);
|
||||
populateMemRefToLLVMConversionPatterns(typeConverter, patterns);
|
||||
cf::populateControlFlowToLLVMConversionPatterns(typeConverter, patterns);
|
||||
populateStdToLLVMConversionPatterns(typeConverter, patterns);
|
||||
|
||||
// The only remaining operation to lower from the `toy` dialect, is the
|
||||
|
|
|
@ -0,0 +1,35 @@
|
|||
//===- ControlFlowToLLVM.h - ControlFlow to LLVM -----------*- C++ ------*-===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// Define conversions from the ControlFlow dialect to the LLVM IR dialect.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef MLIR_CONVERSION_CONTROLFLOWTOLLVM_CONTROLFLOWTOLLVM_H
|
||||
#define MLIR_CONVERSION_CONTROLFLOWTOLLVM_CONTROLFLOWTOLLVM_H
|
||||
|
||||
#include <memory>
|
||||
|
||||
namespace mlir {
|
||||
class LLVMTypeConverter;
|
||||
class RewritePatternSet;
|
||||
class Pass;
|
||||
|
||||
namespace cf {
|
||||
/// Collect the patterns to convert from the ControlFlow dialect to LLVM. The
|
||||
/// conversion patterns capture the LLVMTypeConverter by reference meaning the
|
||||
/// references have to remain alive during the entire pattern lifetime.
|
||||
void populateControlFlowToLLVMConversionPatterns(LLVMTypeConverter &converter,
|
||||
RewritePatternSet &patterns);
|
||||
|
||||
/// Creates a pass to convert the ControlFlow dialect into the LLVMIR dialect.
|
||||
std::unique_ptr<Pass> createConvertControlFlowToLLVMPass();
|
||||
} // namespace cf
|
||||
} // namespace mlir
|
||||
|
||||
#endif // MLIR_CONVERSION_CONTROLFLOWTOLLVM_CONTROLFLOWTOLLVM_H
|
|
@ -0,0 +1,28 @@
|
|||
//===- ControlFlowToSPIRV.h - CF to SPIR-V Patterns --------*- C++ ------*-===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// Provides patterns to convert ControlFlow dialect to SPIR-V dialect.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef MLIR_CONVERSION_CONTROLFLOWTOSPIRV_CONTROLFLOWTOSPIRV_H
|
||||
#define MLIR_CONVERSION_CONTROLFLOWTOSPIRV_CONTROLFLOWTOSPIRV_H
|
||||
|
||||
namespace mlir {
|
||||
class RewritePatternSet;
|
||||
class SPIRVTypeConverter;
|
||||
|
||||
namespace cf {
|
||||
/// Appends to a pattern list additional patterns for translating ControlFlow
|
||||
/// ops to SPIR-V ops.
|
||||
void populateControlFlowToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
|
||||
RewritePatternSet &patterns);
|
||||
} // namespace cf
|
||||
} // namespace mlir
|
||||
|
||||
#endif // MLIR_CONVERSION_CONTROLFLOWTOSPIRV_CONTROLFLOWTOSPIRV_H
|
|
@ -17,6 +17,8 @@
|
|||
#include "mlir/Conversion/BufferizationToMemRef/BufferizationToMemRef.h"
|
||||
#include "mlir/Conversion/ComplexToLLVM/ComplexToLLVM.h"
|
||||
#include "mlir/Conversion/ComplexToStandard/ComplexToStandard.h"
|
||||
#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
|
||||
#include "mlir/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRV.h"
|
||||
#include "mlir/Conversion/GPUCommon/GPUCommonPass.h"
|
||||
#include "mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h"
|
||||
#include "mlir/Conversion/GPUToROCDL/GPUToROCDLPass.h"
|
||||
|
@ -35,10 +37,10 @@
|
|||
#include "mlir/Conversion/OpenMPToLLVM/ConvertOpenMPToLLVM.h"
|
||||
#include "mlir/Conversion/PDLToPDLInterp/PDLToPDLInterp.h"
|
||||
#include "mlir/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.h"
|
||||
#include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h"
|
||||
#include "mlir/Conversion/SCFToGPU/SCFToGPUPass.h"
|
||||
#include "mlir/Conversion/SCFToOpenMP/SCFToOpenMP.h"
|
||||
#include "mlir/Conversion/SCFToSPIRV/SCFToSPIRVPass.h"
|
||||
#include "mlir/Conversion/SCFToStandard/SCFToStandard.h"
|
||||
#include "mlir/Conversion/SPIRVToLLVM/SPIRVToLLVMPass.h"
|
||||
#include "mlir/Conversion/ShapeToStandard/ShapeToStandard.h"
|
||||
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
|
||||
|
|
|
@ -181,6 +181,28 @@ def ConvertComplexToStandard : Pass<"convert-complex-to-standard", "FuncOp"> {
|
|||
let dependentDialects = ["math::MathDialect"];
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ControlFlowToLLVM
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def ConvertControlFlowToLLVM : Pass<"convert-cf-to-llvm", "ModuleOp"> {
|
||||
let summary = "Convert ControlFlow operations to the LLVM dialect";
|
||||
let description = [{
|
||||
Convert ControlFlow operations into LLVM IR dialect operations.
|
||||
|
||||
If other operations are present and their results are required by the LLVM
|
||||
IR dialect operations, the pass will fail. Any LLVM IR operations or types
|
||||
already present in the IR will be kept as is.
|
||||
}];
|
||||
let constructor = "mlir::cf::createConvertControlFlowToLLVMPass()";
|
||||
let dependentDialects = ["LLVM::LLVMDialect"];
|
||||
let options = [
|
||||
Option<"indexBitwidth", "index-bitwidth", "unsigned",
|
||||
/*default=kDeriveIndexBitwidthFromDataLayout*/"0",
|
||||
"Bitwidth of the index type, 0 to use size of machine word">,
|
||||
];
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// GPUCommon
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -460,6 +482,17 @@ def ReconcileUnrealizedCasts : Pass<"reconcile-unrealized-casts"> {
|
|||
let constructor = "mlir::createReconcileUnrealizedCastsPass()";
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// SCFToControlFlow
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def SCFToControlFlow : Pass<"convert-scf-to-cf"> {
|
||||
let summary = "Convert SCF dialect to ControlFlow dialect, replacing structured"
|
||||
" control flow with a CFG";
|
||||
let constructor = "mlir::createConvertSCFToCFPass()";
|
||||
let dependentDialects = ["cf::ControlFlowDialect"];
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// SCFToOpenMP
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -488,17 +521,6 @@ def SCFToSPIRV : Pass<"convert-scf-to-spirv", "ModuleOp"> {
|
|||
let dependentDialects = ["spirv::SPIRVDialect"];
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// SCFToStandard
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def SCFToStandard : Pass<"convert-scf-to-std"> {
|
||||
let summary = "Convert SCF dialect to Standard dialect, replacing structured"
|
||||
" control flow with a CFG";
|
||||
let constructor = "mlir::createLowerToCFGPass()";
|
||||
let dependentDialects = ["StandardOpsDialect"];
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// SCFToGPU
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -547,7 +569,7 @@ def ConvertShapeConstraints: Pass<"convert-shape-constraints", "FuncOp"> {
|
|||
computation lowering.
|
||||
}];
|
||||
let constructor = "mlir::createConvertShapeConstraintsPass()";
|
||||
let dependentDialects = ["StandardOpsDialect", "scf::SCFDialect"];
|
||||
let dependentDialects = ["cf::ControlFlowDialect", "scf::SCFDialect"];
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -0,0 +1,28 @@
|
|||
//===- SCFToControlFlow.h - Pass entrypoint ---------------------*- C++ -*-===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef MLIR_CONVERSION_SCFTOCONTROLFLOW_SCFTOCONTROLFLOW_H_
|
||||
#define MLIR_CONVERSION_SCFTOCONTROLFLOW_SCFTOCONTROLFLOW_H_
|
||||
|
||||
#include <memory>
|
||||
|
||||
namespace mlir {
|
||||
class Pass;
|
||||
class RewritePatternSet;
|
||||
|
||||
/// Collect a set of patterns to convert SCF operations to CFG branch-based
|
||||
/// operations within the ControlFlow dialect.
|
||||
void populateSCFToControlFlowConversionPatterns(RewritePatternSet &patterns);
|
||||
|
||||
/// Creates a pass to convert SCF operations to CFG branch-based operations in
|
||||
/// the ControlFlow dialect.
|
||||
std::unique_ptr<Pass> createConvertSCFToCFPass();
|
||||
|
||||
} // namespace mlir
|
||||
|
||||
#endif // MLIR_CONVERSION_SCFTOCONTROLFLOW_SCFTOCONTROLFLOW_H_
|
|
@ -1,31 +0,0 @@
|
|||
//===- ConvertSCFToStandard.h - Pass entrypoint -----------------*- C++ -*-===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef MLIR_CONVERSION_SCFTOSTANDARD_SCFTOSTANDARD_H_
|
||||
#define MLIR_CONVERSION_SCFTOSTANDARD_SCFTOSTANDARD_H_
|
||||
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
namespace mlir {
|
||||
struct LogicalResult;
|
||||
class Pass;
|
||||
|
||||
class RewritePatternSet;
|
||||
|
||||
/// Collect a set of patterns to lower from scf.for, scf.if, and
|
||||
/// loop.terminator to CFG operations within the Standard dialect, in particular
|
||||
/// convert structured control flow into CFG branch-based control flow.
|
||||
void populateLoopToStdConversionPatterns(RewritePatternSet &patterns);
|
||||
|
||||
/// Creates a pass to convert scf.for, scf.if and loop.terminator ops to CFG.
|
||||
std::unique_ptr<Pass> createLowerToCFGPass();
|
||||
|
||||
} // namespace mlir
|
||||
|
||||
#endif // MLIR_CONVERSION_SCFTOSTANDARD_SCFTOSTANDARD_H_
|
|
@ -26,9 +26,9 @@ def BufferDeallocation : Pass<"buffer-deallocation", "FuncOp"> {
|
|||
#map0 = affine_map<(d0) -> (d0)>
|
||||
module {
|
||||
func @condBranch(%arg0: i1, %arg1: memref<2xf32>, %arg2: memref<2xf32>) {
|
||||
cond_br %arg0, ^bb1, ^bb2
|
||||
cf.cond_br %arg0, ^bb1, ^bb2
|
||||
^bb1:
|
||||
br ^bb3(%arg1 : memref<2xf32>)
|
||||
cf.br ^bb3(%arg1 : memref<2xf32>)
|
||||
^bb2:
|
||||
%0 = memref.alloc() : memref<2xf32>
|
||||
linalg.generic {
|
||||
|
@ -40,7 +40,7 @@ def BufferDeallocation : Pass<"buffer-deallocation", "FuncOp"> {
|
|||
%tmp1 = exp %gen1_arg0 : f32
|
||||
linalg.yield %tmp1 : f32
|
||||
}: memref<2xf32>, memref<2xf32>
|
||||
br ^bb3(%0 : memref<2xf32>)
|
||||
cf.br ^bb3(%0 : memref<2xf32>)
|
||||
^bb3(%1: memref<2xf32>):
|
||||
"memref.copy"(%1, %arg2) : (memref<2xf32>, memref<2xf32>) -> ()
|
||||
return
|
||||
|
@ -55,11 +55,11 @@ def BufferDeallocation : Pass<"buffer-deallocation", "FuncOp"> {
|
|||
#map0 = affine_map<(d0) -> (d0)>
|
||||
module {
|
||||
func @condBranch(%arg0: i1, %arg1: memref<2xf32>, %arg2: memref<2xf32>) {
|
||||
cond_br %arg0, ^bb1, ^bb2
|
||||
cf.cond_br %arg0, ^bb1, ^bb2
|
||||
^bb1: // pred: ^bb0
|
||||
%0 = memref.alloc() : memref<2xf32>
|
||||
memref.copy(%arg1, %0) : memref<2xf32>, memref<2xf32>
|
||||
br ^bb3(%0 : memref<2xf32>)
|
||||
cf.br ^bb3(%0 : memref<2xf32>)
|
||||
^bb2: // pred: ^bb0
|
||||
%1 = memref.alloc() : memref<2xf32>
|
||||
linalg.generic {
|
||||
|
@ -74,7 +74,7 @@ def BufferDeallocation : Pass<"buffer-deallocation", "FuncOp"> {
|
|||
%2 = memref.alloc() : memref<2xf32>
|
||||
memref.copy(%1, %2) : memref<2xf32>, memref<2xf32>
|
||||
dealloc %1 : memref<2xf32>
|
||||
br ^bb3(%2 : memref<2xf32>)
|
||||
cf.br ^bb3(%2 : memref<2xf32>)
|
||||
^bb3(%3: memref<2xf32>): // 2 preds: ^bb1, ^bb2
|
||||
memref.copy(%3, %arg2) : memref<2xf32>, memref<2xf32>
|
||||
dealloc %3 : memref<2xf32>
|
||||
|
|
|
@ -6,6 +6,7 @@ add_subdirectory(ArmSVE)
|
|||
add_subdirectory(AMX)
|
||||
add_subdirectory(Bufferization)
|
||||
add_subdirectory(Complex)
|
||||
add_subdirectory(ControlFlow)
|
||||
add_subdirectory(DLTI)
|
||||
add_subdirectory(EmitC)
|
||||
add_subdirectory(GPU)
|
||||
|
|
|
@ -0,0 +1 @@
|
|||
add_subdirectory(IR)
|
|
@ -0,0 +1,2 @@
|
|||
add_mlir_dialect(ControlFlowOps cf ControlFlowOps)
|
||||
add_mlir_doc(ControlFlowOps ControlFlowDialect Dialects/ -gen-dialect-doc)
|
|
@ -0,0 +1,21 @@
|
|||
//===- ControlFlow.h - ControlFlow Dialect ----------------*- C++ -*-===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// This file defines the ControlFlow dialect.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef MLIR_DIALECT_CONTROLFLOW_IR_CONTROLFLOW_H
|
||||
#define MLIR_DIALECT_CONTROLFLOW_IR_CONTROLFLOW_H
|
||||
|
||||
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
|
||||
#include "mlir/IR/Dialect.h"
|
||||
|
||||
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOpsDialect.h.inc"
|
||||
|
||||
#endif // MLIR_DIALECT_CONTROLFLOW_IR_CONTROLFLOW_H
|
|
@ -0,0 +1,30 @@
|
|||
//===- ControlFlowOps.h - ControlFlow Operations ----------------*- C++ -*-===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// This file defines the operations of the ControlFlow dialect.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef MLIR_DIALECT_CONTROLFLOW_IR_CONTROLFLOWOPS_H
|
||||
#define MLIR_DIALECT_CONTROLFLOW_IR_CONTROLFLOWOPS_H
|
||||
|
||||
#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
#include "mlir/IR/OpImplementation.h"
|
||||
#include "mlir/Interfaces/ControlFlowInterfaces.h"
|
||||
#include "mlir/Interfaces/SideEffectInterfaces.h"
|
||||
|
||||
namespace mlir {
|
||||
class PatternRewriter;
|
||||
} // namespace mlir
|
||||
|
||||
#define GET_OP_CLASSES
|
||||
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h.inc"
|
||||
|
||||
#endif // MLIR_DIALECT_CONTROLFLOW_IR_CONTROLFLOWOPS_H
|
|
@ -0,0 +1,313 @@
|
|||
//===- ControlFlowOps.td - ControlFlow operations ----------*- tablegen -*-===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// This file contains definitions for the operations within the ControlFlow
|
||||
// dialect.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef STANDARD_OPS
|
||||
#define STANDARD_OPS
|
||||
|
||||
include "mlir/IR/OpAsmInterface.td"
|
||||
include "mlir/Interfaces/ControlFlowInterfaces.td"
|
||||
include "mlir/Interfaces/SideEffectInterfaces.td"
|
||||
|
||||
def ControlFlow_Dialect : Dialect {
|
||||
let name = "cf";
|
||||
let cppNamespace = "::mlir::cf";
|
||||
let dependentDialects = ["arith::ArithmeticDialect"];
|
||||
let emitAccessorPrefix = kEmitAccessorPrefix_Prefixed;
|
||||
let description = [{
|
||||
This dialect contains low-level, i.e. non-region based, control flow
|
||||
constructs. These constructs generally represent control flow directly
|
||||
on SSA blocks of a control flow graph.
|
||||
}];
|
||||
}
|
||||
|
||||
class CF_Op<string mnemonic, list<Trait> traits = []> :
|
||||
Op<ControlFlow_Dialect, mnemonic, traits>;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// AssertOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def AssertOp : CF_Op<"assert"> {
|
||||
let summary = "Assert operation with message attribute";
|
||||
let description = [{
|
||||
Assert operation with single boolean operand and an error message attribute.
|
||||
If the argument is `true` this operation has no effect. Otherwise, the
|
||||
program execution will abort. The provided error message may be used by a
|
||||
runtime to propagate the error to the user.
|
||||
|
||||
Example:
|
||||
|
||||
```mlir
|
||||
cf.assert %b, "Expected ... to be true"
|
||||
```
|
||||
}];
|
||||
|
||||
let arguments = (ins I1:$arg, StrAttr:$msg);
|
||||
|
||||
let assemblyFormat = "$arg `,` $msg attr-dict";
|
||||
let hasCanonicalizeMethod = 1;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// BranchOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def BranchOp : CF_Op<"br", [
|
||||
DeclareOpInterfaceMethods<BranchOpInterface, ["getSuccessorForOperands"]>,
|
||||
NoSideEffect, Terminator
|
||||
]> {
|
||||
let summary = "branch operation";
|
||||
let description = [{
|
||||
The `cf.br` operation represents a direct branch operation to a given
|
||||
block. The operands of this operation are forwarded to the successor block,
|
||||
and the number and type of the operands must match the arguments of the
|
||||
target block.
|
||||
|
||||
Example:
|
||||
|
||||
```mlir
|
||||
^bb2:
|
||||
%2 = call @someFn()
|
||||
cf.br ^bb3(%2 : tensor<*xf32>)
|
||||
^bb3(%3: tensor<*xf32>):
|
||||
```
|
||||
}];
|
||||
|
||||
let arguments = (ins Variadic<AnyType>:$destOperands);
|
||||
let successors = (successor AnySuccessor:$dest);
|
||||
|
||||
let builders = [
|
||||
OpBuilder<(ins "Block *":$dest,
|
||||
CArg<"ValueRange", "{}">:$destOperands), [{
|
||||
$_state.addSuccessors(dest);
|
||||
$_state.addOperands(destOperands);
|
||||
}]>];
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
void setDest(Block *block);
|
||||
|
||||
/// Erase the operand at 'index' from the operand list.
|
||||
void eraseOperand(unsigned index);
|
||||
}];
|
||||
|
||||
let hasCanonicalizeMethod = 1;
|
||||
let assemblyFormat = [{
|
||||
$dest (`(` $destOperands^ `:` type($destOperands) `)`)? attr-dict
|
||||
}];
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// CondBranchOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def CondBranchOp : CF_Op<"cond_br",
|
||||
[AttrSizedOperandSegments,
|
||||
DeclareOpInterfaceMethods<BranchOpInterface, ["getSuccessorForOperands"]>,
|
||||
NoSideEffect, Terminator]> {
|
||||
let summary = "conditional branch operation";
|
||||
let description = [{
|
||||
The `cf.cond_br` terminator operation represents a conditional branch on a
|
||||
boolean (1-bit integer) value. If the bit is set, then the first destination
|
||||
is jumped to; if it is false, the second destination is chosen. The count
|
||||
and types of operands must align with the arguments in the corresponding
|
||||
target blocks.
|
||||
|
||||
The MLIR conditional branch operation is not allowed to target the entry
|
||||
block for a region. The two destinations of the conditional branch operation
|
||||
are allowed to be the same.
|
||||
|
||||
The following example illustrates a function with a conditional branch
|
||||
operation that targets the same block.
|
||||
|
||||
Example:
|
||||
|
||||
```mlir
|
||||
func @select(%a: i32, %b: i32, %flag: i1) -> i32 {
|
||||
// Both targets are the same, operands differ
|
||||
cf.cond_br %flag, ^bb1(%a : i32), ^bb1(%b : i32)
|
||||
|
||||
^bb1(%x : i32) :
|
||||
return %x : i32
|
||||
}
|
||||
```
|
||||
}];
|
||||
|
||||
let arguments = (ins I1:$condition,
|
||||
Variadic<AnyType>:$trueDestOperands,
|
||||
Variadic<AnyType>:$falseDestOperands);
|
||||
let successors = (successor AnySuccessor:$trueDest, AnySuccessor:$falseDest);
|
||||
|
||||
let builders = [
|
||||
OpBuilder<(ins "Value":$condition, "Block *":$trueDest,
|
||||
"ValueRange":$trueOperands, "Block *":$falseDest,
|
||||
"ValueRange":$falseOperands), [{
|
||||
build($_builder, $_state, condition, trueOperands, falseOperands, trueDest,
|
||||
falseDest);
|
||||
}]>,
|
||||
OpBuilder<(ins "Value":$condition, "Block *":$trueDest,
|
||||
"Block *":$falseDest, CArg<"ValueRange", "{}">:$falseOperands), [{
|
||||
build($_builder, $_state, condition, trueDest, ValueRange(), falseDest,
|
||||
falseOperands);
|
||||
}]>];
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
// These are the indices into the dests list.
|
||||
enum { trueIndex = 0, falseIndex = 1 };
|
||||
|
||||
// Accessors for operands to the 'true' destination.
|
||||
Value getTrueOperand(unsigned idx) {
|
||||
assert(idx < getNumTrueOperands());
|
||||
return getOperand(getTrueDestOperandIndex() + idx);
|
||||
}
|
||||
|
||||
void setTrueOperand(unsigned idx, Value value) {
|
||||
assert(idx < getNumTrueOperands());
|
||||
setOperand(getTrueDestOperandIndex() + idx, value);
|
||||
}
|
||||
|
||||
unsigned getNumTrueOperands() { return getTrueOperands().size(); }
|
||||
|
||||
/// Erase the operand at 'index' from the true operand list.
|
||||
void eraseTrueOperand(unsigned index) {
|
||||
getTrueDestOperandsMutable().erase(index);
|
||||
}
|
||||
|
||||
// Accessors for operands to the 'false' destination.
|
||||
Value getFalseOperand(unsigned idx) {
|
||||
assert(idx < getNumFalseOperands());
|
||||
return getOperand(getFalseDestOperandIndex() + idx);
|
||||
}
|
||||
void setFalseOperand(unsigned idx, Value value) {
|
||||
assert(idx < getNumFalseOperands());
|
||||
setOperand(getFalseDestOperandIndex() + idx, value);
|
||||
}
|
||||
|
||||
operand_range getTrueOperands() { return getTrueDestOperands(); }
|
||||
operand_range getFalseOperands() { return getFalseDestOperands(); }
|
||||
|
||||
unsigned getNumFalseOperands() { return getFalseOperands().size(); }
|
||||
|
||||
/// Erase the operand at 'index' from the false operand list.
|
||||
void eraseFalseOperand(unsigned index) {
|
||||
getFalseDestOperandsMutable().erase(index);
|
||||
}
|
||||
|
||||
private:
|
||||
/// Get the index of the first true destination operand.
|
||||
unsigned getTrueDestOperandIndex() { return 1; }
|
||||
|
||||
/// Get the index of the first false destination operand.
|
||||
unsigned getFalseDestOperandIndex() {
|
||||
return getTrueDestOperandIndex() + getNumTrueOperands();
|
||||
}
|
||||
}];
|
||||
|
||||
let hasCanonicalizer = 1;
|
||||
let assemblyFormat = [{
|
||||
$condition `,`
|
||||
$trueDest (`(` $trueDestOperands^ `:` type($trueDestOperands) `)`)? `,`
|
||||
$falseDest (`(` $falseDestOperands^ `:` type($falseDestOperands) `)`)?
|
||||
attr-dict
|
||||
}];
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// SwitchOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def SwitchOp : CF_Op<"switch",
|
||||
[AttrSizedOperandSegments,
|
||||
DeclareOpInterfaceMethods<BranchOpInterface, ["getSuccessorForOperands"]>,
|
||||
NoSideEffect, Terminator]> {
|
||||
let summary = "switch operation";
|
||||
let description = [{
|
||||
The `cf.switch` terminator operation represents a switch on a signless integer
|
||||
value. If the flag matches one of the specified cases, then the
|
||||
corresponding destination is jumped to. If the flag does not match any of
|
||||
the cases, the default destination is jumped to. The count and types of
|
||||
operands must align with the arguments in the corresponding target blocks.
|
||||
|
||||
Example:
|
||||
|
||||
```mlir
|
||||
cf.switch %flag : i32, [
|
||||
default: ^bb1(%a : i32),
|
||||
42: ^bb1(%b : i32),
|
||||
43: ^bb3(%c : i32)
|
||||
]
|
||||
```
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
AnyInteger:$flag,
|
||||
Variadic<AnyType>:$defaultOperands,
|
||||
VariadicOfVariadic<AnyType, "case_operand_segments">:$caseOperands,
|
||||
OptionalAttr<AnyIntElementsAttr>:$case_values,
|
||||
I32ElementsAttr:$case_operand_segments
|
||||
);
|
||||
let successors = (successor
|
||||
AnySuccessor:$defaultDestination,
|
||||
VariadicSuccessor<AnySuccessor>:$caseDestinations
|
||||
);
|
||||
let builders = [
|
||||
OpBuilder<(ins "Value":$flag,
|
||||
"Block *":$defaultDestination,
|
||||
"ValueRange":$defaultOperands,
|
||||
CArg<"ArrayRef<APInt>", "{}">:$caseValues,
|
||||
CArg<"BlockRange", "{}">:$caseDestinations,
|
||||
CArg<"ArrayRef<ValueRange>", "{}">:$caseOperands)>,
|
||||
OpBuilder<(ins "Value":$flag,
|
||||
"Block *":$defaultDestination,
|
||||
"ValueRange":$defaultOperands,
|
||||
CArg<"ArrayRef<int32_t>", "{}">:$caseValues,
|
||||
CArg<"BlockRange", "{}">:$caseDestinations,
|
||||
CArg<"ArrayRef<ValueRange>", "{}">:$caseOperands)>,
|
||||
OpBuilder<(ins "Value":$flag,
|
||||
"Block *":$defaultDestination,
|
||||
"ValueRange":$defaultOperands,
|
||||
CArg<"DenseIntElementsAttr", "{}">:$caseValues,
|
||||
CArg<"BlockRange", "{}">:$caseDestinations,
|
||||
CArg<"ArrayRef<ValueRange>", "{}">:$caseOperands)>
|
||||
];
|
||||
|
||||
let assemblyFormat = [{
|
||||
$flag `:` type($flag) `,` `[` `\n`
|
||||
custom<SwitchOpCases>(ref(type($flag)),$defaultDestination,
|
||||
$defaultOperands,
|
||||
type($defaultOperands),
|
||||
$case_values,
|
||||
$caseDestinations,
|
||||
$caseOperands,
|
||||
type($caseOperands))
|
||||
`]`
|
||||
attr-dict
|
||||
}];
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
/// Return the operands for the case destination block at the given index.
|
||||
OperandRange getCaseOperands(unsigned index) {
|
||||
return getCaseOperands()[index];
|
||||
}
|
||||
|
||||
/// Return a mutable range of operands for the case destination block at the
|
||||
/// given index.
|
||||
MutableOperandRange getCaseOperandsMutable(unsigned index) {
|
||||
return getCaseOperandsMutable()[index];
|
||||
}
|
||||
}];
|
||||
|
||||
let hasCanonicalizer = 1;
|
||||
let hasVerifier = 1;
|
||||
}
|
||||
|
||||
#endif // STANDARD_OPS
|
|
@ -84,15 +84,15 @@ def ExecuteRegionOp : SCF_Op<"execute_region"> {
|
|||
affine.for %i = 0 to 100 {
|
||||
"foo"() : () -> ()
|
||||
%v = scf.execute_region -> i64 {
|
||||
cond_br %cond, ^bb1, ^bb2
|
||||
cf.cond_br %cond, ^bb1, ^bb2
|
||||
|
||||
^bb1:
|
||||
%c1 = arith.constant 1 : i64
|
||||
br ^bb3(%c1 : i64)
|
||||
cf.br ^bb3(%c1 : i64)
|
||||
|
||||
^bb2:
|
||||
%c2 = arith.constant 2 : i64
|
||||
br ^bb3(%c2 : i64)
|
||||
cf.br ^bb3(%c2 : i64)
|
||||
|
||||
^bb3(%x : i64):
|
||||
scf.yield %x : i64
|
||||
|
|
|
@ -14,7 +14,7 @@
|
|||
#ifndef MLIR_DIALECT_STANDARDOPS_IR_OPS_H
|
||||
#define MLIR_DIALECT_STANDARDOPS_IR_OPS_H
|
||||
|
||||
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
|
||||
#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
#include "mlir/IR/Dialect.h"
|
||||
|
@ -24,7 +24,6 @@
|
|||
#include "mlir/Interfaces/ControlFlowInterfaces.h"
|
||||
#include "mlir/Interfaces/InferTypeOpInterface.h"
|
||||
#include "mlir/Interfaces/SideEffectInterfaces.h"
|
||||
#include "mlir/Interfaces/VectorInterfaces.h"
|
||||
|
||||
// Pull in all enum type definitions and utility function declarations.
|
||||
#include "mlir/Dialect/StandardOps/IR/OpsEnums.h.inc"
|
||||
|
|
|
@ -20,12 +20,11 @@ include "mlir/Interfaces/CastInterfaces.td"
|
|||
include "mlir/Interfaces/ControlFlowInterfaces.td"
|
||||
include "mlir/Interfaces/InferTypeOpInterface.td"
|
||||
include "mlir/Interfaces/SideEffectInterfaces.td"
|
||||
include "mlir/Interfaces/VectorInterfaces.td"
|
||||
|
||||
def StandardOps_Dialect : Dialect {
|
||||
let name = "std";
|
||||
let cppNamespace = "::mlir";
|
||||
let dependentDialects = ["arith::ArithmeticDialect"];
|
||||
let dependentDialects = ["cf::ControlFlowDialect"];
|
||||
let hasConstantMaterializer = 1;
|
||||
let emitAccessorPrefix = kEmitAccessorPrefix_Prefixed;
|
||||
}
|
||||
|
@ -42,78 +41,6 @@ class Std_Op<string mnemonic, list<Trait> traits = []> :
|
|||
let parser = [{ return ::parse$cppClass(parser, result); }];
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// AssertOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def AssertOp : Std_Op<"assert"> {
|
||||
let summary = "Assert operation with message attribute";
|
||||
let description = [{
|
||||
Assert operation with single boolean operand and an error message attribute.
|
||||
If the argument is `true` this operation has no effect. Otherwise, the
|
||||
program execution will abort. The provided error message may be used by a
|
||||
runtime to propagate the error to the user.
|
||||
|
||||
Example:
|
||||
|
||||
```mlir
|
||||
assert %b, "Expected ... to be true"
|
||||
```
|
||||
}];
|
||||
|
||||
let arguments = (ins I1:$arg, StrAttr:$msg);
|
||||
|
||||
let assemblyFormat = "$arg `,` $msg attr-dict";
|
||||
let hasCanonicalizeMethod = 1;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// BranchOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def BranchOp : Std_Op<"br",
|
||||
[DeclareOpInterfaceMethods<BranchOpInterface, ["getSuccessorForOperands"]>,
|
||||
NoSideEffect, Terminator]> {
|
||||
let summary = "branch operation";
|
||||
let description = [{
|
||||
The `br` operation represents a branch operation in a function.
|
||||
The operation takes variable number of operands and produces no results.
|
||||
The operand number and types for each successor must match the arguments of
|
||||
the block successor.
|
||||
|
||||
Example:
|
||||
|
||||
```mlir
|
||||
^bb2:
|
||||
%2 = call @someFn()
|
||||
br ^bb3(%2 : tensor<*xf32>)
|
||||
^bb3(%3: tensor<*xf32>):
|
||||
```
|
||||
}];
|
||||
|
||||
let arguments = (ins Variadic<AnyType>:$destOperands);
|
||||
let successors = (successor AnySuccessor:$dest);
|
||||
|
||||
let builders = [
|
||||
OpBuilder<(ins "Block *":$dest,
|
||||
CArg<"ValueRange", "{}">:$destOperands), [{
|
||||
$_state.addSuccessors(dest);
|
||||
$_state.addOperands(destOperands);
|
||||
}]>];
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
void setDest(Block *block);
|
||||
|
||||
/// Erase the operand at 'index' from the operand list.
|
||||
void eraseOperand(unsigned index);
|
||||
}];
|
||||
|
||||
let hasCanonicalizeMethod = 1;
|
||||
let assemblyFormat = [{
|
||||
$dest (`(` $destOperands^ `:` type($destOperands) `)`)? attr-dict
|
||||
}];
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// CallOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -246,121 +173,6 @@ def CallIndirectOp : Std_Op<"call_indirect", [
|
|||
"$callee `(` $callee_operands `)` attr-dict `:` type($callee)";
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// CondBranchOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def CondBranchOp : Std_Op<"cond_br",
|
||||
[AttrSizedOperandSegments,
|
||||
DeclareOpInterfaceMethods<BranchOpInterface, ["getSuccessorForOperands"]>,
|
||||
NoSideEffect, Terminator]> {
|
||||
let summary = "conditional branch operation";
|
||||
let description = [{
|
||||
The `cond_br` terminator operation represents a conditional branch on a
|
||||
boolean (1-bit integer) value. If the bit is set, then the first destination
|
||||
is jumped to; if it is false, the second destination is chosen. The count
|
||||
and types of operands must align with the arguments in the corresponding
|
||||
target blocks.
|
||||
|
||||
The MLIR conditional branch operation is not allowed to target the entry
|
||||
block for a region. The two destinations of the conditional branch operation
|
||||
are allowed to be the same.
|
||||
|
||||
The following example illustrates a function with a conditional branch
|
||||
operation that targets the same block.
|
||||
|
||||
Example:
|
||||
|
||||
```mlir
|
||||
func @select(%a: i32, %b: i32, %flag: i1) -> i32 {
|
||||
// Both targets are the same, operands differ
|
||||
cond_br %flag, ^bb1(%a : i32), ^bb1(%b : i32)
|
||||
|
||||
^bb1(%x : i32) :
|
||||
return %x : i32
|
||||
}
|
||||
```
|
||||
}];
|
||||
|
||||
let arguments = (ins I1:$condition,
|
||||
Variadic<AnyType>:$trueDestOperands,
|
||||
Variadic<AnyType>:$falseDestOperands);
|
||||
let successors = (successor AnySuccessor:$trueDest, AnySuccessor:$falseDest);
|
||||
|
||||
let builders = [
|
||||
OpBuilder<(ins "Value":$condition, "Block *":$trueDest,
|
||||
"ValueRange":$trueOperands, "Block *":$falseDest,
|
||||
"ValueRange":$falseOperands), [{
|
||||
build($_builder, $_state, condition, trueOperands, falseOperands, trueDest,
|
||||
falseDest);
|
||||
}]>,
|
||||
OpBuilder<(ins "Value":$condition, "Block *":$trueDest,
|
||||
"Block *":$falseDest, CArg<"ValueRange", "{}">:$falseOperands), [{
|
||||
build($_builder, $_state, condition, trueDest, ValueRange(), falseDest,
|
||||
falseOperands);
|
||||
}]>];
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
// These are the indices into the dests list.
|
||||
enum { trueIndex = 0, falseIndex = 1 };
|
||||
|
||||
// Accessors for operands to the 'true' destination.
|
||||
Value getTrueOperand(unsigned idx) {
|
||||
assert(idx < getNumTrueOperands());
|
||||
return getOperand(getTrueDestOperandIndex() + idx);
|
||||
}
|
||||
|
||||
void setTrueOperand(unsigned idx, Value value) {
|
||||
assert(idx < getNumTrueOperands());
|
||||
setOperand(getTrueDestOperandIndex() + idx, value);
|
||||
}
|
||||
|
||||
unsigned getNumTrueOperands() { return getTrueOperands().size(); }
|
||||
|
||||
/// Erase the operand at 'index' from the true operand list.
|
||||
void eraseTrueOperand(unsigned index) {
|
||||
getTrueDestOperandsMutable().erase(index);
|
||||
}
|
||||
|
||||
// Accessors for operands to the 'false' destination.
|
||||
Value getFalseOperand(unsigned idx) {
|
||||
assert(idx < getNumFalseOperands());
|
||||
return getOperand(getFalseDestOperandIndex() + idx);
|
||||
}
|
||||
void setFalseOperand(unsigned idx, Value value) {
|
||||
assert(idx < getNumFalseOperands());
|
||||
setOperand(getFalseDestOperandIndex() + idx, value);
|
||||
}
|
||||
|
||||
operand_range getTrueOperands() { return getTrueDestOperands(); }
|
||||
operand_range getFalseOperands() { return getFalseDestOperands(); }
|
||||
|
||||
unsigned getNumFalseOperands() { return getFalseOperands().size(); }
|
||||
|
||||
/// Erase the operand at 'index' from the false operand list.
|
||||
void eraseFalseOperand(unsigned index) {
|
||||
getFalseDestOperandsMutable().erase(index);
|
||||
}
|
||||
|
||||
private:
|
||||
/// Get the index of the first true destination operand.
|
||||
unsigned getTrueDestOperandIndex() { return 1; }
|
||||
|
||||
/// Get the index of the first false destination operand.
|
||||
unsigned getFalseDestOperandIndex() {
|
||||
return getTrueDestOperandIndex() + getNumTrueOperands();
|
||||
}
|
||||
}];
|
||||
|
||||
let hasCanonicalizer = 1;
|
||||
let assemblyFormat = [{
|
||||
$condition `,`
|
||||
$trueDest (`(` $trueDestOperands^ `:` type($trueDestOperands) `)`)? `,`
|
||||
$falseDest (`(` $falseDestOperands^ `:` type($falseDestOperands) `)`)?
|
||||
attr-dict
|
||||
}];
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ConstantOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -443,93 +255,4 @@ def ReturnOp : Std_Op<"return", [NoSideEffect, HasParent<"FuncOp">,
|
|||
let hasVerifier = 1;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// SwitchOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def SwitchOp : Std_Op<"switch",
|
||||
[AttrSizedOperandSegments,
|
||||
DeclareOpInterfaceMethods<BranchOpInterface, ["getSuccessorForOperands"]>,
|
||||
NoSideEffect, Terminator]> {
|
||||
let summary = "switch operation";
|
||||
let description = [{
|
||||
The `switch` terminator operation represents a switch on a signless integer
|
||||
value. If the flag matches one of the specified cases, then the
|
||||
corresponding destination is jumped to. If the flag does not match any of
|
||||
the cases, the default destination is jumped to. The count and types of
|
||||
operands must align with the arguments in the corresponding target blocks.
|
||||
|
||||
Example:
|
||||
|
||||
```mlir
|
||||
switch %flag : i32, [
|
||||
default: ^bb1(%a : i32),
|
||||
42: ^bb1(%b : i32),
|
||||
43: ^bb3(%c : i32)
|
||||
]
|
||||
```
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
AnyInteger:$flag,
|
||||
Variadic<AnyType>:$defaultOperands,
|
||||
VariadicOfVariadic<AnyType, "case_operand_segments">:$caseOperands,
|
||||
OptionalAttr<AnyIntElementsAttr>:$case_values,
|
||||
I32ElementsAttr:$case_operand_segments
|
||||
);
|
||||
let successors = (successor
|
||||
AnySuccessor:$defaultDestination,
|
||||
VariadicSuccessor<AnySuccessor>:$caseDestinations
|
||||
);
|
||||
let builders = [
|
||||
OpBuilder<(ins "Value":$flag,
|
||||
"Block *":$defaultDestination,
|
||||
"ValueRange":$defaultOperands,
|
||||
CArg<"ArrayRef<APInt>", "{}">:$caseValues,
|
||||
CArg<"BlockRange", "{}">:$caseDestinations,
|
||||
CArg<"ArrayRef<ValueRange>", "{}">:$caseOperands)>,
|
||||
OpBuilder<(ins "Value":$flag,
|
||||
"Block *":$defaultDestination,
|
||||
"ValueRange":$defaultOperands,
|
||||
CArg<"ArrayRef<int32_t>", "{}">:$caseValues,
|
||||
CArg<"BlockRange", "{}">:$caseDestinations,
|
||||
CArg<"ArrayRef<ValueRange>", "{}">:$caseOperands)>,
|
||||
OpBuilder<(ins "Value":$flag,
|
||||
"Block *":$defaultDestination,
|
||||
"ValueRange":$defaultOperands,
|
||||
CArg<"DenseIntElementsAttr", "{}">:$caseValues,
|
||||
CArg<"BlockRange", "{}">:$caseDestinations,
|
||||
CArg<"ArrayRef<ValueRange>", "{}">:$caseOperands)>
|
||||
];
|
||||
|
||||
let assemblyFormat = [{
|
||||
$flag `:` type($flag) `,` `[` `\n`
|
||||
custom<SwitchOpCases>(ref(type($flag)),$defaultDestination,
|
||||
$defaultOperands,
|
||||
type($defaultOperands),
|
||||
$case_values,
|
||||
$caseDestinations,
|
||||
$caseOperands,
|
||||
type($caseOperands))
|
||||
`]`
|
||||
attr-dict
|
||||
}];
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
/// Return the operands for the case destination block at the given index.
|
||||
OperandRange getCaseOperands(unsigned index) {
|
||||
return getCaseOperands()[index];
|
||||
}
|
||||
|
||||
/// Return a mutable range of operands for the case destination block at the
|
||||
/// given index.
|
||||
MutableOperandRange getCaseOperandsMutable(unsigned index) {
|
||||
return getCaseOperandsMutable()[index];
|
||||
}
|
||||
}];
|
||||
|
||||
let hasCanonicalizer = 1;
|
||||
let hasVerifier = 1;
|
||||
}
|
||||
|
||||
#endif // STANDARD_OPS
|
||||
|
|
|
@ -22,6 +22,7 @@
|
|||
#include "mlir/Dialect/Async/IR/Async.h"
|
||||
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
|
||||
#include "mlir/Dialect/Complex/IR/Complex.h"
|
||||
#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
|
||||
#include "mlir/Dialect/DLTI/DLTI.h"
|
||||
#include "mlir/Dialect/EmitC/IR/EmitC.h"
|
||||
#include "mlir/Dialect/GPU/GPUDialect.h"
|
||||
|
@ -61,6 +62,7 @@ inline void registerAllDialects(DialectRegistry &registry) {
|
|||
arm_neon::ArmNeonDialect,
|
||||
async::AsyncDialect,
|
||||
bufferization::BufferizationDialect,
|
||||
cf::ControlFlowDialect,
|
||||
complex::ComplexDialect,
|
||||
DLTIDialect,
|
||||
emitc::EmitCDialect,
|
||||
|
|
|
@ -6,6 +6,8 @@ add_subdirectory(AsyncToLLVM)
|
|||
add_subdirectory(BufferizationToMemRef)
|
||||
add_subdirectory(ComplexToLLVM)
|
||||
add_subdirectory(ComplexToStandard)
|
||||
add_subdirectory(ControlFlowToLLVM)
|
||||
add_subdirectory(ControlFlowToSPIRV)
|
||||
add_subdirectory(GPUCommon)
|
||||
add_subdirectory(GPUToNVVM)
|
||||
add_subdirectory(GPUToROCDL)
|
||||
|
@ -25,10 +27,10 @@ add_subdirectory(OpenACCToSCF)
|
|||
add_subdirectory(OpenMPToLLVM)
|
||||
add_subdirectory(PDLToPDLInterp)
|
||||
add_subdirectory(ReconcileUnrealizedCasts)
|
||||
add_subdirectory(SCFToControlFlow)
|
||||
add_subdirectory(SCFToGPU)
|
||||
add_subdirectory(SCFToOpenMP)
|
||||
add_subdirectory(SCFToSPIRV)
|
||||
add_subdirectory(SCFToStandard)
|
||||
add_subdirectory(ShapeToStandard)
|
||||
add_subdirectory(SPIRVToLLVM)
|
||||
add_subdirectory(StandardToLLVM)
|
||||
|
|
|
@ -0,0 +1,21 @@
|
|||
add_mlir_conversion_library(MLIRControlFlowToLLVM
|
||||
ControlFlowToLLVM.cpp
|
||||
|
||||
ADDITIONAL_HEADER_DIRS
|
||||
${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/ControlFlowToLLVM
|
||||
|
||||
DEPENDS
|
||||
MLIRConversionPassIncGen
|
||||
intrinsics_gen
|
||||
|
||||
LINK_COMPONENTS
|
||||
Core
|
||||
|
||||
LINK_LIBS PUBLIC
|
||||
MLIRAnalysis
|
||||
MLIRControlFlow
|
||||
MLIRLLVMCommonConversion
|
||||
MLIRLLVMIR
|
||||
MLIRPass
|
||||
MLIRTransformUtils
|
||||
)
|
|
@ -0,0 +1,148 @@
|
|||
//===- ControlFlowToLLVM.cpp - ControlFlow to LLVM dialect conversion -----===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// This file implements a pass to convert the MLIR ControlFlow (cf) dialect
|
||||
// into the LLVM IR dialect.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
|
||||
#include "../PassDetail.h"
|
||||
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
|
||||
#include "mlir/Conversion/LLVMCommon/Pattern.h"
|
||||
#include "mlir/Conversion/LLVMCommon/VectorPattern.h"
|
||||
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
|
||||
#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
|
||||
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
|
||||
#include "mlir/IR/BuiltinOps.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
#include <functional>
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
#define PASS_NAME "convert-cf-to-llvm"
|
||||
|
||||
namespace {
|
||||
/// Lower `cf.assert`. The default lowering calls the `abort` function if the
|
||||
/// assertion is violated and has no effect otherwise. The failure message is
|
||||
/// ignored by the default lowering but should be propagated by any custom
|
||||
/// lowering.
|
||||
struct AssertOpLowering : public ConvertOpToLLVMPattern<cf::AssertOp> {
|
||||
using ConvertOpToLLVMPattern<cf::AssertOp>::ConvertOpToLLVMPattern;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(cf::AssertOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto loc = op.getLoc();
|
||||
|
||||
// Insert the `abort` declaration if necessary.
|
||||
auto module = op->getParentOfType<ModuleOp>();
|
||||
auto abortFunc = module.lookupSymbol<LLVM::LLVMFuncOp>("abort");
|
||||
if (!abortFunc) {
|
||||
OpBuilder::InsertionGuard guard(rewriter);
|
||||
rewriter.setInsertionPointToStart(module.getBody());
|
||||
auto abortFuncTy = LLVM::LLVMFunctionType::get(getVoidType(), {});
|
||||
abortFunc = rewriter.create<LLVM::LLVMFuncOp>(rewriter.getUnknownLoc(),
|
||||
"abort", abortFuncTy);
|
||||
}
|
||||
|
||||
// Split block at `assert` operation.
|
||||
Block *opBlock = rewriter.getInsertionBlock();
|
||||
auto opPosition = rewriter.getInsertionPoint();
|
||||
Block *continuationBlock = rewriter.splitBlock(opBlock, opPosition);
|
||||
|
||||
// Generate IR to call `abort`.
|
||||
Block *failureBlock = rewriter.createBlock(opBlock->getParent());
|
||||
rewriter.create<LLVM::CallOp>(loc, abortFunc, llvm::None);
|
||||
rewriter.create<LLVM::UnreachableOp>(loc);
|
||||
|
||||
// Generate assertion test.
|
||||
rewriter.setInsertionPointToEnd(opBlock);
|
||||
rewriter.replaceOpWithNewOp<LLVM::CondBrOp>(
|
||||
op, adaptor.getArg(), continuationBlock, failureBlock);
|
||||
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
// Base class for LLVM IR lowering terminator operations with successors.
|
||||
template <typename SourceOp, typename TargetOp>
|
||||
struct OneToOneLLVMTerminatorLowering
|
||||
: public ConvertOpToLLVMPattern<SourceOp> {
|
||||
using ConvertOpToLLVMPattern<SourceOp>::ConvertOpToLLVMPattern;
|
||||
using Base = OneToOneLLVMTerminatorLowering<SourceOp, TargetOp>;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
rewriter.replaceOpWithNewOp<TargetOp>(op, adaptor.getOperands(),
|
||||
op->getSuccessors(), op->getAttrs());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
// FIXME: this should be tablegen'ed as well.
|
||||
struct BranchOpLowering
|
||||
: public OneToOneLLVMTerminatorLowering<cf::BranchOp, LLVM::BrOp> {
|
||||
using Base::Base;
|
||||
};
|
||||
struct CondBranchOpLowering
|
||||
: public OneToOneLLVMTerminatorLowering<cf::CondBranchOp, LLVM::CondBrOp> {
|
||||
using Base::Base;
|
||||
};
|
||||
struct SwitchOpLowering
|
||||
: public OneToOneLLVMTerminatorLowering<cf::SwitchOp, LLVM::SwitchOp> {
|
||||
using Base::Base;
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
void mlir::cf::populateControlFlowToLLVMConversionPatterns(
|
||||
LLVMTypeConverter &converter, RewritePatternSet &patterns) {
|
||||
// clang-format off
|
||||
patterns.add<
|
||||
AssertOpLowering,
|
||||
BranchOpLowering,
|
||||
CondBranchOpLowering,
|
||||
SwitchOpLowering>(converter);
|
||||
// clang-format on
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Pass Definition
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
namespace {
|
||||
/// A pass converting ControlFlow dialect operations into the LLVM IR dialect.
|
||||
struct ConvertControlFlowToLLVM
|
||||
: public ConvertControlFlowToLLVMBase<ConvertControlFlowToLLVM> {
|
||||
ConvertControlFlowToLLVM() = default;
|
||||
|
||||
/// Run the dialect converter on the module.
|
||||
void runOnOperation() override {
|
||||
LLVMConversionTarget target(getContext());
|
||||
RewritePatternSet patterns(&getContext());
|
||||
|
||||
LowerToLLVMOptions options(&getContext());
|
||||
if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout)
|
||||
options.overrideIndexBitwidth(indexBitwidth);
|
||||
|
||||
LLVMTypeConverter converter(&getContext(), options);
|
||||
mlir::cf::populateControlFlowToLLVMConversionPatterns(converter, patterns);
|
||||
|
||||
if (failed(applyPartialConversion(getOperation(), target,
|
||||
std::move(patterns))))
|
||||
signalPassFailure();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
std::unique_ptr<Pass> mlir::cf::createConvertControlFlowToLLVMPass() {
|
||||
return std::make_unique<ConvertControlFlowToLLVM>();
|
||||
}
|
|
@ -0,0 +1,19 @@
|
|||
add_mlir_conversion_library(MLIRControlFlowToSPIRV
|
||||
ControlFlowToSPIRV.cpp
|
||||
|
||||
ADDITIONAL_HEADER_DIRS
|
||||
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/SPIRV
|
||||
${MLIR_MAIN_INCLUDE_DIR}/mlir/IR
|
||||
|
||||
DEPENDS
|
||||
MLIRConversionPassIncGen
|
||||
|
||||
LINK_LIBS PUBLIC
|
||||
MLIRIR
|
||||
MLIRControlFlow
|
||||
MLIRPass
|
||||
MLIRSPIRV
|
||||
MLIRSPIRVConversion
|
||||
MLIRSupport
|
||||
MLIRTransformUtils
|
||||
)
|
|
@ -0,0 +1,73 @@
|
|||
//===- ControlFlowToSPIRV.cpp - ControlFlow to SPIR-V Patterns ------------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// This file implements patterns converting the ControlFlow dialect to SPIR-V.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRV.h"
|
||||
#include "../SPIRVCommon/Pattern.h"
|
||||
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
|
||||
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
|
||||
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
|
||||
#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
|
||||
#include "mlir/Dialect/SPIRV/Utils/LayoutUtils.h"
|
||||
#include "mlir/IR/AffineMap.h"
|
||||
#include "mlir/Support/LogicalResult.h"
|
||||
#include "llvm/ADT/SetVector.h"
|
||||
#include "llvm/Support/Debug.h"
|
||||
|
||||
#define DEBUG_TYPE "cf-to-spirv-pattern"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Operation conversion
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
namespace {
|
||||
|
||||
/// Converts cf.br to spv.Branch.
|
||||
struct BranchOpPattern final : public OpConversionPattern<cf::BranchOp> {
|
||||
using OpConversionPattern<cf::BranchOp>::OpConversionPattern;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(cf::BranchOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
rewriter.replaceOpWithNewOp<spirv::BranchOp>(op, op.getDest(),
|
||||
adaptor.getDestOperands());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
/// Converts cf.cond_br to spv.BranchConditional.
|
||||
struct CondBranchOpPattern final
|
||||
: public OpConversionPattern<cf::CondBranchOp> {
|
||||
using OpConversionPattern<cf::CondBranchOp>::OpConversionPattern;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(cf::CondBranchOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
rewriter.replaceOpWithNewOp<spirv::BranchConditionalOp>(
|
||||
op, op.getCondition(), op.getTrueDest(), adaptor.getTrueDestOperands(),
|
||||
op.getFalseDest(), adaptor.getFalseDestOperands());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Pattern population
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
void mlir::cf::populateControlFlowToSPIRVPatterns(
|
||||
SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns) {
|
||||
MLIRContext *context = patterns.getContext();
|
||||
|
||||
patterns.add<BranchOpPattern, CondBranchOpPattern>(typeConverter, context);
|
||||
}
|
|
@ -14,6 +14,7 @@
|
|||
#include "mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h"
|
||||
|
||||
#include "mlir/Conversion/ArithmeticToLLVM/ArithmeticToLLVM.h"
|
||||
#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
|
||||
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
|
||||
#include "mlir/Conversion/LLVMCommon/LoweringOptions.h"
|
||||
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
|
||||
|
@ -172,8 +173,8 @@ struct LowerGpuOpsToNVVMOpsPass
|
|||
populateGpuRewritePatterns(patterns);
|
||||
(void)applyPatternsAndFoldGreedily(m, std::move(patterns));
|
||||
|
||||
mlir::arith::populateArithmeticToLLVMConversionPatterns(converter,
|
||||
llvmPatterns);
|
||||
arith::populateArithmeticToLLVMConversionPatterns(converter, llvmPatterns);
|
||||
cf::populateControlFlowToLLVMConversionPatterns(converter, llvmPatterns);
|
||||
populateStdToLLVMConversionPatterns(converter, llvmPatterns);
|
||||
populateMemRefToLLVMConversionPatterns(converter, llvmPatterns);
|
||||
populateGpuToNVVMConversionPatterns(converter, llvmPatterns);
|
||||
|
|
|
@ -19,7 +19,7 @@ add_mlir_conversion_library(MLIRLinalgToLLVM
|
|||
MLIRLLVMCommonConversion
|
||||
MLIRLLVMIR
|
||||
MLIRMemRefToLLVM
|
||||
MLIRSCFToStandard
|
||||
MLIRSCFToControlFlow
|
||||
MLIRTransforms
|
||||
MLIRVectorToLLVM
|
||||
MLIRVectorToSCF
|
||||
|
|
|
@ -14,7 +14,7 @@
|
|||
#include "mlir/Conversion/LLVMCommon/Pattern.h"
|
||||
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
|
||||
#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
|
||||
#include "mlir/Conversion/SCFToStandard/SCFToStandard.h"
|
||||
#include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h"
|
||||
#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
|
||||
#include "mlir/Conversion/VectorToSCF/VectorToSCF.h"
|
||||
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
|
||||
|
|
|
@ -433,7 +433,7 @@ struct LoadStoreOpLowering : public ConvertOpToLLVMPattern<Derived> {
|
|||
/// +---------------------------------+
|
||||
/// | <code before the AtomicRMWOp> |
|
||||
/// | <compute initial %loaded> |
|
||||
/// | br loop(%loaded) |
|
||||
/// | cf.br loop(%loaded) |
|
||||
/// +---------------------------------+
|
||||
/// |
|
||||
/// -------| |
|
||||
|
@ -444,7 +444,7 @@ struct LoadStoreOpLowering : public ConvertOpToLLVMPattern<Derived> {
|
|||
/// | | %pair = cmpxchg |
|
||||
/// | | %ok = %pair[0] |
|
||||
/// | | %new = %pair[1] |
|
||||
/// | | cond_br %ok, end, loop(%new) |
|
||||
/// | | cf.cond_br %ok, end, loop(%new) |
|
||||
/// | +--------------------------------+
|
||||
/// | | |
|
||||
/// |----------- |
|
||||
|
|
|
@ -10,6 +10,7 @@
|
|||
|
||||
#include "../PassDetail.h"
|
||||
#include "mlir/Conversion/ArithmeticToLLVM/ArithmeticToLLVM.h"
|
||||
#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
|
||||
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
|
||||
#include "mlir/Conversion/LLVMCommon/Pattern.h"
|
||||
#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
|
||||
|
@ -66,7 +67,8 @@ void ConvertOpenMPToLLVMPass::runOnOperation() {
|
|||
// Convert to OpenMP operations with LLVM IR dialect
|
||||
RewritePatternSet patterns(&getContext());
|
||||
LLVMTypeConverter converter(&getContext());
|
||||
mlir::arith::populateArithmeticToLLVMConversionPatterns(converter, patterns);
|
||||
arith::populateArithmeticToLLVMConversionPatterns(converter, patterns);
|
||||
cf::populateControlFlowToLLVMConversionPatterns(converter, patterns);
|
||||
populateMemRefToLLVMConversionPatterns(converter, patterns);
|
||||
populateStdToLLVMConversionPatterns(converter, patterns);
|
||||
populateOpenMPToLLVMConversionPatterns(converter, patterns);
|
||||
|
|
|
@ -29,6 +29,10 @@ namespace arith {
|
|||
class ArithmeticDialect;
|
||||
} // namespace arith
|
||||
|
||||
namespace cf {
|
||||
class ControlFlowDialect;
|
||||
} // namespace cf
|
||||
|
||||
namespace complex {
|
||||
class ComplexDialect;
|
||||
} // namespace complex
|
||||
|
|
|
@ -1,8 +1,8 @@
|
|||
add_mlir_conversion_library(MLIRSCFToStandard
|
||||
SCFToStandard.cpp
|
||||
add_mlir_conversion_library(MLIRSCFToControlFlow
|
||||
SCFToControlFlow.cpp
|
||||
|
||||
ADDITIONAL_HEADER_DIRS
|
||||
${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/SCFToStandard
|
||||
${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/SCFToControlFlow
|
||||
|
||||
DEPENDS
|
||||
MLIRConversionPassIncGen
|
||||
|
@ -12,6 +12,7 @@ add_mlir_conversion_library(MLIRSCFToStandard
|
|||
|
||||
LINK_LIBS PUBLIC
|
||||
MLIRArithmetic
|
||||
MLIRControlFlow
|
||||
MLIRSCF
|
||||
MLIRTransforms
|
||||
)
|
|
@ -1,4 +1,4 @@
|
|||
//===- SCFToStandard.cpp - ControlFlow to CFG conversion ------------------===//
|
||||
//===- SCFToControlFlow.cpp - SCF to CF conversion ------------------------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
|
@ -11,11 +11,11 @@
|
|||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/Conversion/SCFToStandard/SCFToStandard.h"
|
||||
#include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h"
|
||||
#include "../PassDetail.h"
|
||||
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
|
||||
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
|
||||
#include "mlir/Dialect/SCF/SCF.h"
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||
#include "mlir/IR/BlockAndValueMapping.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/BuiltinOps.h"
|
||||
|
@ -29,7 +29,8 @@ using namespace mlir::scf;
|
|||
|
||||
namespace {
|
||||
|
||||
struct SCFToStandardPass : public SCFToStandardBase<SCFToStandardPass> {
|
||||
struct SCFToControlFlowPass
|
||||
: public SCFToControlFlowBase<SCFToControlFlowPass> {
|
||||
void runOnOperation() override;
|
||||
};
|
||||
|
||||
|
@ -57,7 +58,7 @@ struct SCFToStandardPass : public SCFToStandardBase<SCFToStandardPass> {
|
|||
// | <code before the ForOp> |
|
||||
// | <definitions of %init...> |
|
||||
// | <compute initial %iv value> |
|
||||
// | br cond(%iv, %init...) |
|
||||
// | cf.br cond(%iv, %init...) |
|
||||
// +---------------------------------+
|
||||
// |
|
||||
// -------| |
|
||||
|
@ -65,7 +66,7 @@ struct SCFToStandardPass : public SCFToStandardBase<SCFToStandardPass> {
|
|||
// | +--------------------------------+
|
||||
// | | cond(%iv, %init...): |
|
||||
// | | <compare %iv to upper bound> |
|
||||
// | | cond_br %r, body, end |
|
||||
// | | cf.cond_br %r, body, end |
|
||||
// | +--------------------------------+
|
||||
// | | |
|
||||
// | | -------------|
|
||||
|
@ -83,7 +84,7 @@ struct SCFToStandardPass : public SCFToStandardBase<SCFToStandardPass> {
|
|||
// | | <body contents> | |
|
||||
// | | <operands of yield = %yields>| |
|
||||
// | | %new_iv =<add step to %iv> | |
|
||||
// | | br cond(%new_iv, %yields) | |
|
||||
// | | cf.br cond(%new_iv, %yields) | |
|
||||
// | +--------------------------------+ |
|
||||
// | | |
|
||||
// |----------- |--------------------
|
||||
|
@ -125,7 +126,7 @@ struct ForLowering : public OpRewritePattern<ForOp> {
|
|||
//
|
||||
// +--------------------------------+
|
||||
// | <code before the IfOp> |
|
||||
// | cond_br %cond, %then, %else |
|
||||
// | cf.cond_br %cond, %then, %else |
|
||||
// +--------------------------------+
|
||||
// | |
|
||||
// | --------------|
|
||||
|
@ -133,7 +134,7 @@ struct ForLowering : public OpRewritePattern<ForOp> {
|
|||
// +--------------------------------+ |
|
||||
// | then: | |
|
||||
// | <then contents> | |
|
||||
// | br continue | |
|
||||
// | cf.br continue | |
|
||||
// +--------------------------------+ |
|
||||
// | |
|
||||
// |---------- |-------------
|
||||
|
@ -141,7 +142,7 @@ struct ForLowering : public OpRewritePattern<ForOp> {
|
|||
// | +--------------------------------+
|
||||
// | | else: |
|
||||
// | | <else contents> |
|
||||
// | | br continue |
|
||||
// | | cf.br continue |
|
||||
// | +--------------------------------+
|
||||
// | |
|
||||
// ------| |
|
||||
|
@ -155,7 +156,7 @@ struct ForLowering : public OpRewritePattern<ForOp> {
|
|||
//
|
||||
// +--------------------------------+
|
||||
// | <code before the IfOp> |
|
||||
// | cond_br %cond, %then, %else |
|
||||
// | cf.cond_br %cond, %then, %else |
|
||||
// +--------------------------------+
|
||||
// | |
|
||||
// | --------------|
|
||||
|
@ -163,7 +164,7 @@ struct ForLowering : public OpRewritePattern<ForOp> {
|
|||
// +--------------------------------+ |
|
||||
// | then: | |
|
||||
// | <then contents> | |
|
||||
// | br dom(%args...) | |
|
||||
// | cf.br dom(%args...) | |
|
||||
// +--------------------------------+ |
|
||||
// | |
|
||||
// |---------- |-------------
|
||||
|
@ -171,14 +172,14 @@ struct ForLowering : public OpRewritePattern<ForOp> {
|
|||
// | +--------------------------------+
|
||||
// | | else: |
|
||||
// | | <else contents> |
|
||||
// | | br dom(%args...) |
|
||||
// | | cf.br dom(%args...) |
|
||||
// | +--------------------------------+
|
||||
// | |
|
||||
// ------| |
|
||||
// v v
|
||||
// +--------------------------------+
|
||||
// | dom(%args...): |
|
||||
// | br continue |
|
||||
// | cf.br continue |
|
||||
// +--------------------------------+
|
||||
// |
|
||||
// v
|
||||
|
@ -218,7 +219,7 @@ struct ParallelLowering : public OpRewritePattern<mlir::scf::ParallelOp> {
|
|||
///
|
||||
/// +---------------------------------+
|
||||
/// | <code before the WhileOp> |
|
||||
/// | br ^before(%operands...) |
|
||||
/// | cf.br ^before(%operands...) |
|
||||
/// +---------------------------------+
|
||||
/// |
|
||||
/// -------| |
|
||||
|
@ -233,7 +234,7 @@ struct ParallelLowering : public OpRewritePattern<mlir::scf::ParallelOp> {
|
|||
/// | +--------------------------------+
|
||||
/// | | ^before-last:
|
||||
/// | | %cond = <compute condition> |
|
||||
/// | | cond_br %cond, |
|
||||
/// | | cf.cond_br %cond, |
|
||||
/// | | ^after(%vals...), ^cont |
|
||||
/// | +--------------------------------+
|
||||
/// | | |
|
||||
|
@ -249,7 +250,7 @@ struct ParallelLowering : public OpRewritePattern<mlir::scf::ParallelOp> {
|
|||
/// | +--------------------------------+ |
|
||||
/// | | ^after-last: | |
|
||||
/// | | %yields... = <some payload> | |
|
||||
/// | | br ^before(%yields...) | |
|
||||
/// | | cf.br ^before(%yields...) | |
|
||||
/// | +--------------------------------+ |
|
||||
/// | | |
|
||||
/// |----------- |--------------------
|
||||
|
@ -321,7 +322,7 @@ LogicalResult ForLowering::matchAndRewrite(ForOp forOp,
|
|||
SmallVector<Value, 8> loopCarried;
|
||||
loopCarried.push_back(stepped);
|
||||
loopCarried.append(terminator->operand_begin(), terminator->operand_end());
|
||||
rewriter.create<BranchOp>(loc, conditionBlock, loopCarried);
|
||||
rewriter.create<cf::BranchOp>(loc, conditionBlock, loopCarried);
|
||||
rewriter.eraseOp(terminator);
|
||||
|
||||
// Compute loop bounds before branching to the condition.
|
||||
|
@ -337,15 +338,16 @@ LogicalResult ForLowering::matchAndRewrite(ForOp forOp,
|
|||
destOperands.push_back(lowerBound);
|
||||
auto iterOperands = forOp.getIterOperands();
|
||||
destOperands.append(iterOperands.begin(), iterOperands.end());
|
||||
rewriter.create<BranchOp>(loc, conditionBlock, destOperands);
|
||||
rewriter.create<cf::BranchOp>(loc, conditionBlock, destOperands);
|
||||
|
||||
// With the body block done, we can fill in the condition block.
|
||||
rewriter.setInsertionPointToEnd(conditionBlock);
|
||||
auto comparison = rewriter.create<arith::CmpIOp>(
|
||||
loc, arith::CmpIPredicate::slt, iv, upperBound);
|
||||
|
||||
rewriter.create<CondBranchOp>(loc, comparison, firstBodyBlock,
|
||||
ArrayRef<Value>(), endBlock, ArrayRef<Value>());
|
||||
rewriter.create<cf::CondBranchOp>(loc, comparison, firstBodyBlock,
|
||||
ArrayRef<Value>(), endBlock,
|
||||
ArrayRef<Value>());
|
||||
// The result of the loop operation is the values of the condition block
|
||||
// arguments except the induction variable on the last iteration.
|
||||
rewriter.replaceOp(forOp, conditionBlock->getArguments().drop_front());
|
||||
|
@ -369,7 +371,7 @@ LogicalResult IfLowering::matchAndRewrite(IfOp ifOp,
|
|||
continueBlock =
|
||||
rewriter.createBlock(remainingOpsBlock, ifOp.getResultTypes(),
|
||||
SmallVector<Location>(ifOp.getNumResults(), loc));
|
||||
rewriter.create<BranchOp>(loc, remainingOpsBlock);
|
||||
rewriter.create<cf::BranchOp>(loc, remainingOpsBlock);
|
||||
}
|
||||
|
||||
// Move blocks from the "then" region to the region containing 'scf.if',
|
||||
|
@ -379,7 +381,7 @@ LogicalResult IfLowering::matchAndRewrite(IfOp ifOp,
|
|||
Operation *thenTerminator = thenRegion.back().getTerminator();
|
||||
ValueRange thenTerminatorOperands = thenTerminator->getOperands();
|
||||
rewriter.setInsertionPointToEnd(&thenRegion.back());
|
||||
rewriter.create<BranchOp>(loc, continueBlock, thenTerminatorOperands);
|
||||
rewriter.create<cf::BranchOp>(loc, continueBlock, thenTerminatorOperands);
|
||||
rewriter.eraseOp(thenTerminator);
|
||||
rewriter.inlineRegionBefore(thenRegion, continueBlock);
|
||||
|
||||
|
@ -393,15 +395,15 @@ LogicalResult IfLowering::matchAndRewrite(IfOp ifOp,
|
|||
Operation *elseTerminator = elseRegion.back().getTerminator();
|
||||
ValueRange elseTerminatorOperands = elseTerminator->getOperands();
|
||||
rewriter.setInsertionPointToEnd(&elseRegion.back());
|
||||
rewriter.create<BranchOp>(loc, continueBlock, elseTerminatorOperands);
|
||||
rewriter.create<cf::BranchOp>(loc, continueBlock, elseTerminatorOperands);
|
||||
rewriter.eraseOp(elseTerminator);
|
||||
rewriter.inlineRegionBefore(elseRegion, continueBlock);
|
||||
}
|
||||
|
||||
rewriter.setInsertionPointToEnd(condBlock);
|
||||
rewriter.create<CondBranchOp>(loc, ifOp.getCondition(), thenBlock,
|
||||
/*trueArgs=*/ArrayRef<Value>(), elseBlock,
|
||||
/*falseArgs=*/ArrayRef<Value>());
|
||||
rewriter.create<cf::CondBranchOp>(loc, ifOp.getCondition(), thenBlock,
|
||||
/*trueArgs=*/ArrayRef<Value>(), elseBlock,
|
||||
/*falseArgs=*/ArrayRef<Value>());
|
||||
|
||||
// Ok, we're done!
|
||||
rewriter.replaceOp(ifOp, continueBlock->getArguments());
|
||||
|
@ -419,13 +421,13 @@ ExecuteRegionLowering::matchAndRewrite(ExecuteRegionOp op,
|
|||
|
||||
auto &region = op.getRegion();
|
||||
rewriter.setInsertionPointToEnd(condBlock);
|
||||
rewriter.create<BranchOp>(loc, &region.front());
|
||||
rewriter.create<cf::BranchOp>(loc, &region.front());
|
||||
|
||||
for (Block &block : region) {
|
||||
if (auto terminator = dyn_cast<scf::YieldOp>(block.getTerminator())) {
|
||||
ValueRange terminatorOperands = terminator->getOperands();
|
||||
rewriter.setInsertionPointToEnd(&block);
|
||||
rewriter.create<BranchOp>(loc, remainingOpsBlock, terminatorOperands);
|
||||
rewriter.create<cf::BranchOp>(loc, remainingOpsBlock, terminatorOperands);
|
||||
rewriter.eraseOp(terminator);
|
||||
}
|
||||
}
|
||||
|
@ -538,20 +540,21 @@ LogicalResult WhileLowering::matchAndRewrite(WhileOp whileOp,
|
|||
|
||||
// Branch to the "before" region.
|
||||
rewriter.setInsertionPointToEnd(currentBlock);
|
||||
rewriter.create<BranchOp>(loc, before, whileOp.getInits());
|
||||
rewriter.create<cf::BranchOp>(loc, before, whileOp.getInits());
|
||||
|
||||
// Replace terminators with branches. Assuming bodies are SESE, which holds
|
||||
// given only the patterns from this file, we only need to look at the last
|
||||
// block. This should be reconsidered if we allow break/continue in SCF.
|
||||
rewriter.setInsertionPointToEnd(beforeLast);
|
||||
auto condOp = cast<ConditionOp>(beforeLast->getTerminator());
|
||||
rewriter.replaceOpWithNewOp<CondBranchOp>(condOp, condOp.getCondition(),
|
||||
after, condOp.getArgs(),
|
||||
continuation, ValueRange());
|
||||
rewriter.replaceOpWithNewOp<cf::CondBranchOp>(condOp, condOp.getCondition(),
|
||||
after, condOp.getArgs(),
|
||||
continuation, ValueRange());
|
||||
|
||||
rewriter.setInsertionPointToEnd(afterLast);
|
||||
auto yieldOp = cast<scf::YieldOp>(afterLast->getTerminator());
|
||||
rewriter.replaceOpWithNewOp<BranchOp>(yieldOp, before, yieldOp.getResults());
|
||||
rewriter.replaceOpWithNewOp<cf::BranchOp>(yieldOp, before,
|
||||
yieldOp.getResults());
|
||||
|
||||
// Replace the op with values "yielded" from the "before" region, which are
|
||||
// visible by dominance.
|
||||
|
@ -593,14 +596,14 @@ DoWhileLowering::matchAndRewrite(WhileOp whileOp,
|
|||
|
||||
// Branch to the "before" region.
|
||||
rewriter.setInsertionPointToEnd(currentBlock);
|
||||
rewriter.create<BranchOp>(whileOp.getLoc(), before, whileOp.getInits());
|
||||
rewriter.create<cf::BranchOp>(whileOp.getLoc(), before, whileOp.getInits());
|
||||
|
||||
// Loop around the "before" region based on condition.
|
||||
rewriter.setInsertionPointToEnd(beforeLast);
|
||||
auto condOp = cast<ConditionOp>(beforeLast->getTerminator());
|
||||
rewriter.replaceOpWithNewOp<CondBranchOp>(condOp, condOp.getCondition(),
|
||||
before, condOp.getArgs(),
|
||||
continuation, ValueRange());
|
||||
rewriter.replaceOpWithNewOp<cf::CondBranchOp>(condOp, condOp.getCondition(),
|
||||
before, condOp.getArgs(),
|
||||
continuation, ValueRange());
|
||||
|
||||
// Replace the op with values "yielded" from the "before" region, which are
|
||||
// visible by dominance.
|
||||
|
@ -609,17 +612,18 @@ DoWhileLowering::matchAndRewrite(WhileOp whileOp,
|
|||
return success();
|
||||
}
|
||||
|
||||
void mlir::populateLoopToStdConversionPatterns(RewritePatternSet &patterns) {
|
||||
void mlir::populateSCFToControlFlowConversionPatterns(
|
||||
RewritePatternSet &patterns) {
|
||||
patterns.add<ForLowering, IfLowering, ParallelLowering, WhileLowering,
|
||||
ExecuteRegionLowering>(patterns.getContext());
|
||||
patterns.add<DoWhileLowering>(patterns.getContext(), /*benefit=*/2);
|
||||
}
|
||||
|
||||
void SCFToStandardPass::runOnOperation() {
|
||||
void SCFToControlFlowPass::runOnOperation() {
|
||||
RewritePatternSet patterns(&getContext());
|
||||
populateLoopToStdConversionPatterns(patterns);
|
||||
// Configure conversion to lower out scf.for, scf.if, scf.parallel and
|
||||
// scf.while. Anything else is fine.
|
||||
populateSCFToControlFlowConversionPatterns(patterns);
|
||||
|
||||
// Configure conversion to lower out SCF operations.
|
||||
ConversionTarget target(getContext());
|
||||
target.addIllegalOp<scf::ForOp, scf::IfOp, scf::ParallelOp, scf::WhileOp,
|
||||
scf::ExecuteRegionOp>();
|
||||
|
@ -629,6 +633,6 @@ void SCFToStandardPass::runOnOperation() {
|
|||
signalPassFailure();
|
||||
}
|
||||
|
||||
std::unique_ptr<Pass> mlir::createLowerToCFGPass() {
|
||||
return std::make_unique<SCFToStandardPass>();
|
||||
std::unique_ptr<Pass> mlir::createConvertSCFToCFPass() {
|
||||
return std::make_unique<SCFToControlFlowPass>();
|
||||
}
|
|
@ -9,6 +9,7 @@
|
|||
#include "mlir/Conversion/ShapeToStandard/ShapeToStandard.h"
|
||||
|
||||
#include "../PassDetail.h"
|
||||
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
|
||||
#include "mlir/Dialect/SCF/SCF.h"
|
||||
#include "mlir/Dialect/Shape/IR/Shape.h"
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||
|
@ -29,7 +30,7 @@ public:
|
|||
using OpRewritePattern::OpRewritePattern;
|
||||
LogicalResult matchAndRewrite(shape::CstrRequireOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
rewriter.create<AssertOp>(op.getLoc(), op.getPred(), op.getMsgAttr());
|
||||
rewriter.create<cf::AssertOp>(op.getLoc(), op.getPred(), op.getMsgAttr());
|
||||
rewriter.replaceOpWithNewOp<shape::ConstWitnessOp>(op, true);
|
||||
return success();
|
||||
}
|
||||
|
|
|
@ -14,6 +14,7 @@ add_mlir_conversion_library(MLIRStandardToLLVM
|
|||
LINK_LIBS PUBLIC
|
||||
MLIRAnalysis
|
||||
MLIRArithmeticToLLVM
|
||||
MLIRControlFlowToLLVM
|
||||
MLIRDataLayoutInterfaces
|
||||
MLIRLLVMCommonConversion
|
||||
MLIRLLVMIR
|
||||
|
|
|
@ -14,6 +14,7 @@
|
|||
#include "../PassDetail.h"
|
||||
#include "mlir/Analysis/DataLayoutAnalysis.h"
|
||||
#include "mlir/Conversion/ArithmeticToLLVM/ArithmeticToLLVM.h"
|
||||
#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
|
||||
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
|
||||
#include "mlir/Conversion/LLVMCommon/Pattern.h"
|
||||
#include "mlir/Conversion/LLVMCommon/VectorPattern.h"
|
||||
|
@ -387,48 +388,6 @@ struct BarePtrFuncOpConversion : public FuncOpConversionBase {
|
|||
}
|
||||
};
|
||||
|
||||
/// Lower `std.assert`. The default lowering calls the `abort` function if the
|
||||
/// assertion is violated and has no effect otherwise. The failure message is
|
||||
/// ignored by the default lowering but should be propagated by any custom
|
||||
/// lowering.
|
||||
struct AssertOpLowering : public ConvertOpToLLVMPattern<AssertOp> {
|
||||
using ConvertOpToLLVMPattern<AssertOp>::ConvertOpToLLVMPattern;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(AssertOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto loc = op.getLoc();
|
||||
|
||||
// Insert the `abort` declaration if necessary.
|
||||
auto module = op->getParentOfType<ModuleOp>();
|
||||
auto abortFunc = module.lookupSymbol<LLVM::LLVMFuncOp>("abort");
|
||||
if (!abortFunc) {
|
||||
OpBuilder::InsertionGuard guard(rewriter);
|
||||
rewriter.setInsertionPointToStart(module.getBody());
|
||||
auto abortFuncTy = LLVM::LLVMFunctionType::get(getVoidType(), {});
|
||||
abortFunc = rewriter.create<LLVM::LLVMFuncOp>(rewriter.getUnknownLoc(),
|
||||
"abort", abortFuncTy);
|
||||
}
|
||||
|
||||
// Split block at `assert` operation.
|
||||
Block *opBlock = rewriter.getInsertionBlock();
|
||||
auto opPosition = rewriter.getInsertionPoint();
|
||||
Block *continuationBlock = rewriter.splitBlock(opBlock, opPosition);
|
||||
|
||||
// Generate IR to call `abort`.
|
||||
Block *failureBlock = rewriter.createBlock(opBlock->getParent());
|
||||
rewriter.create<LLVM::CallOp>(loc, abortFunc, llvm::None);
|
||||
rewriter.create<LLVM::UnreachableOp>(loc);
|
||||
|
||||
// Generate assertion test.
|
||||
rewriter.setInsertionPointToEnd(opBlock);
|
||||
rewriter.replaceOpWithNewOp<LLVM::CondBrOp>(
|
||||
op, adaptor.getArg(), continuationBlock, failureBlock);
|
||||
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
struct ConstantOpLowering : public ConvertOpToLLVMPattern<ConstantOp> {
|
||||
using ConvertOpToLLVMPattern<ConstantOp>::ConvertOpToLLVMPattern;
|
||||
|
||||
|
@ -550,22 +509,6 @@ struct UnrealizedConversionCastOpLowering
|
|||
}
|
||||
};
|
||||
|
||||
// Base class for LLVM IR lowering terminator operations with successors.
|
||||
template <typename SourceOp, typename TargetOp>
|
||||
struct OneToOneLLVMTerminatorLowering
|
||||
: public ConvertOpToLLVMPattern<SourceOp> {
|
||||
using ConvertOpToLLVMPattern<SourceOp>::ConvertOpToLLVMPattern;
|
||||
using Super = OneToOneLLVMTerminatorLowering<SourceOp, TargetOp>;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
rewriter.replaceOpWithNewOp<TargetOp>(op, adaptor.getOperands(),
|
||||
op->getSuccessors(), op->getAttrs());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
// Special lowering pattern for `ReturnOps`. Unlike all other operations,
|
||||
// `ReturnOp` interacts with the function signature and must have as many
|
||||
// operands as the function has return values. Because in LLVM IR, functions
|
||||
|
@ -633,21 +576,6 @@ struct ReturnOpLowering : public ConvertOpToLLVMPattern<ReturnOp> {
|
|||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
// FIXME: this should be tablegen'ed as well.
|
||||
struct BranchOpLowering
|
||||
: public OneToOneLLVMTerminatorLowering<BranchOp, LLVM::BrOp> {
|
||||
using Super::Super;
|
||||
};
|
||||
struct CondBranchOpLowering
|
||||
: public OneToOneLLVMTerminatorLowering<CondBranchOp, LLVM::CondBrOp> {
|
||||
using Super::Super;
|
||||
};
|
||||
struct SwitchOpLowering
|
||||
: public OneToOneLLVMTerminatorLowering<SwitchOp, LLVM::SwitchOp> {
|
||||
using Super::Super;
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
void mlir::populateStdToLLVMFuncOpConversionPattern(
|
||||
|
@ -663,14 +591,10 @@ void mlir::populateStdToLLVMConversionPatterns(LLVMTypeConverter &converter,
|
|||
populateStdToLLVMFuncOpConversionPattern(converter, patterns);
|
||||
// clang-format off
|
||||
patterns.add<
|
||||
AssertOpLowering,
|
||||
BranchOpLowering,
|
||||
CallIndirectOpLowering,
|
||||
CallOpLowering,
|
||||
CondBranchOpLowering,
|
||||
ConstantOpLowering,
|
||||
ReturnOpLowering,
|
||||
SwitchOpLowering>(converter);
|
||||
ReturnOpLowering>(converter);
|
||||
// clang-format on
|
||||
}
|
||||
|
||||
|
@ -721,6 +645,7 @@ struct LLVMLoweringPass : public ConvertStandardToLLVMBase<LLVMLoweringPass> {
|
|||
RewritePatternSet patterns(&getContext());
|
||||
populateStdToLLVMConversionPatterns(typeConverter, patterns);
|
||||
arith::populateArithmeticToLLVMConversionPatterns(typeConverter, patterns);
|
||||
cf::populateControlFlowToLLVMConversionPatterns(typeConverter, patterns);
|
||||
|
||||
LLVMConversionTarget target(getContext());
|
||||
if (failed(applyPartialConversion(m, target, std::move(patterns))))
|
||||
|
|
|
@ -11,6 +11,7 @@ add_mlir_conversion_library(MLIRStandardToSPIRV
|
|||
|
||||
LINK_LIBS PUBLIC
|
||||
MLIRArithmeticToSPIRV
|
||||
MLIRControlFlowToSPIRV
|
||||
MLIRIR
|
||||
MLIRMathToSPIRV
|
||||
MLIRMemRef
|
||||
|
|
|
@ -46,24 +46,6 @@ public:
|
|||
ConversionPatternRewriter &rewriter) const override;
|
||||
};
|
||||
|
||||
/// Converts std.br to spv.Branch.
|
||||
struct BranchOpPattern final : public OpConversionPattern<BranchOp> {
|
||||
using OpConversionPattern<BranchOp>::OpConversionPattern;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(BranchOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override;
|
||||
};
|
||||
|
||||
/// Converts std.cond_br to spv.BranchConditional.
|
||||
struct CondBranchOpPattern final : public OpConversionPattern<CondBranchOp> {
|
||||
using OpConversionPattern<CondBranchOp>::OpConversionPattern;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(CondBranchOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override;
|
||||
};
|
||||
|
||||
/// Converts tensor.extract into loading using access chains from SPIR-V local
|
||||
/// variables.
|
||||
class TensorExtractPattern final
|
||||
|
@ -146,31 +128,6 @@ ReturnOpPattern::matchAndRewrite(ReturnOp returnOp, OpAdaptor adaptor,
|
|||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// BranchOpPattern
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
LogicalResult
|
||||
BranchOpPattern::matchAndRewrite(BranchOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
rewriter.replaceOpWithNewOp<spirv::BranchOp>(op, op.getDest(),
|
||||
adaptor.getDestOperands());
|
||||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// CondBranchOpPattern
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
LogicalResult CondBranchOpPattern::matchAndRewrite(
|
||||
CondBranchOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
rewriter.replaceOpWithNewOp<spirv::BranchConditionalOp>(
|
||||
op, op.getCondition(), op.getTrueDest(), adaptor.getTrueDestOperands(),
|
||||
op.getFalseDest(), adaptor.getFalseDestOperands());
|
||||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Pattern population
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -189,8 +146,7 @@ void populateStandardToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
|
|||
spirv::ElementwiseOpPattern<arith::MinSIOp, spirv::GLSLSMinOp>,
|
||||
spirv::ElementwiseOpPattern<arith::MinUIOp, spirv::GLSLUMinOp>,
|
||||
|
||||
ReturnOpPattern, BranchOpPattern, CondBranchOpPattern>(typeConverter,
|
||||
context);
|
||||
ReturnOpPattern>(typeConverter, context);
|
||||
}
|
||||
|
||||
void populateTensorToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
|
||||
|
|
|
@ -13,6 +13,7 @@
|
|||
#include "mlir/Conversion/StandardToSPIRV/StandardToSPIRVPass.h"
|
||||
#include "../PassDetail.h"
|
||||
#include "mlir/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.h"
|
||||
#include "mlir/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRV.h"
|
||||
#include "mlir/Conversion/MathToSPIRV/MathToSPIRV.h"
|
||||
#include "mlir/Conversion/StandardToSPIRV/StandardToSPIRV.h"
|
||||
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
|
||||
|
@ -40,9 +41,11 @@ void ConvertStandardToSPIRVPass::runOnOperation() {
|
|||
options.emulateNon32BitScalarTypes = this->emulateNon32BitScalarTypes;
|
||||
SPIRVTypeConverter typeConverter(targetAttr, options);
|
||||
|
||||
// TODO ArithmeticToSPIRV cannot be applied separately to StandardToSPIRV
|
||||
// TODO ArithmeticToSPIRV/ControlFlowToSPIRV cannot be applied separately to
|
||||
// StandardToSPIRV
|
||||
RewritePatternSet patterns(context);
|
||||
arith::populateArithmeticToSPIRVPatterns(typeConverter, patterns);
|
||||
cf::populateControlFlowToSPIRVPatterns(typeConverter, patterns);
|
||||
populateMathToSPIRVPatterns(typeConverter, patterns);
|
||||
populateStandardToSPIRVPatterns(typeConverter, patterns);
|
||||
populateTensorToSPIRVPatterns(typeConverter, /*byteCountThreshold=*/64,
|
||||
|
|
|
@ -15,6 +15,7 @@
|
|||
#include "mlir/Analysis/Liveness.h"
|
||||
#include "mlir/Dialect/Async/IR/Async.h"
|
||||
#include "mlir/Dialect/Async/Passes.h"
|
||||
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||
#include "mlir/IR/ImplicitLocOpBuilder.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
|
@ -169,11 +170,11 @@ private:
|
|||
///
|
||||
/// ^entry:
|
||||
/// %token = async.runtime.create : !async.token
|
||||
/// cond_br %cond, ^bb1, ^bb2
|
||||
/// cf.cond_br %cond, ^bb1, ^bb2
|
||||
/// ^bb1:
|
||||
/// async.runtime.await %token
|
||||
/// async.runtime.drop_ref %token
|
||||
/// br ^bb2
|
||||
/// cf.br ^bb2
|
||||
/// ^bb2:
|
||||
/// return
|
||||
///
|
||||
|
@ -185,14 +186,14 @@ private:
|
|||
///
|
||||
/// ^entry:
|
||||
/// %token = async.runtime.create : !async.token
|
||||
/// cond_br %cond, ^bb1, ^reference_counting
|
||||
/// cf.cond_br %cond, ^bb1, ^reference_counting
|
||||
/// ^bb1:
|
||||
/// async.runtime.await %token
|
||||
/// async.runtime.drop_ref %token
|
||||
/// br ^bb2
|
||||
/// cf.br ^bb2
|
||||
/// ^reference_counting:
|
||||
/// async.runtime.drop_ref %token
|
||||
/// br ^bb2
|
||||
/// cf.br ^bb2
|
||||
/// ^bb2:
|
||||
/// return
|
||||
///
|
||||
|
@ -208,7 +209,7 @@ private:
|
|||
/// async.coro.suspend %ret, ^suspend, ^resume, ^cleanup
|
||||
/// ^resume:
|
||||
/// %0 = async.runtime.load %value
|
||||
/// br ^cleanup
|
||||
/// cf.br ^cleanup
|
||||
/// ^cleanup:
|
||||
/// ...
|
||||
/// ^suspend:
|
||||
|
@ -406,7 +407,7 @@ AsyncRuntimeRefCountingPass::addDropRefInDivergentLivenessSuccessor(
|
|||
refCountingBlock = &successor->getParent()->emplaceBlock();
|
||||
refCountingBlock->moveBefore(successor);
|
||||
OpBuilder builder = OpBuilder::atBlockEnd(refCountingBlock);
|
||||
builder.create<BranchOp>(value.getLoc(), successor);
|
||||
builder.create<cf::BranchOp>(value.getLoc(), successor);
|
||||
}
|
||||
|
||||
OpBuilder builder = OpBuilder::atBlockBegin(refCountingBlock);
|
||||
|
|
|
@ -12,10 +12,11 @@
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "PassDetail.h"
|
||||
#include "mlir/Conversion/SCFToStandard/SCFToStandard.h"
|
||||
#include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h"
|
||||
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
|
||||
#include "mlir/Dialect/Async/IR/Async.h"
|
||||
#include "mlir/Dialect/Async/Passes.h"
|
||||
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
|
||||
#include "mlir/Dialect/SCF/SCF.h"
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||
#include "mlir/IR/BlockAndValueMapping.h"
|
||||
|
@ -105,18 +106,18 @@ struct CoroMachinery {
|
|||
/// %value = <async value> : !async.value<T> // create async value
|
||||
/// %id = async.coro.id // create a coroutine id
|
||||
/// %hdl = async.coro.begin %id // create a coroutine handle
|
||||
/// br ^preexisting_entry_block
|
||||
/// cf.br ^preexisting_entry_block
|
||||
///
|
||||
/// /* preexisting blocks modified to branch to the cleanup block */
|
||||
///
|
||||
/// ^set_error: // this block created lazily only if needed (see code below)
|
||||
/// async.runtime.set_error %token : !async.token
|
||||
/// async.runtime.set_error %value : !async.value<T>
|
||||
/// br ^cleanup
|
||||
/// cf.br ^cleanup
|
||||
///
|
||||
/// ^cleanup:
|
||||
/// async.coro.free %hdl // delete the coroutine state
|
||||
/// br ^suspend
|
||||
/// cf.br ^suspend
|
||||
///
|
||||
/// ^suspend:
|
||||
/// async.coro.end %hdl // marks the end of a coroutine
|
||||
|
@ -147,7 +148,7 @@ static CoroMachinery setupCoroMachinery(FuncOp func) {
|
|||
auto coroIdOp = builder.create<CoroIdOp>(CoroIdType::get(ctx));
|
||||
auto coroHdlOp =
|
||||
builder.create<CoroBeginOp>(CoroHandleType::get(ctx), coroIdOp.id());
|
||||
builder.create<BranchOp>(originalEntryBlock);
|
||||
builder.create<cf::BranchOp>(originalEntryBlock);
|
||||
|
||||
Block *cleanupBlock = func.addBlock();
|
||||
Block *suspendBlock = func.addBlock();
|
||||
|
@ -159,7 +160,7 @@ static CoroMachinery setupCoroMachinery(FuncOp func) {
|
|||
builder.create<CoroFreeOp>(coroIdOp.id(), coroHdlOp.handle());
|
||||
|
||||
// Branch into the suspend block.
|
||||
builder.create<BranchOp>(suspendBlock);
|
||||
builder.create<cf::BranchOp>(suspendBlock);
|
||||
|
||||
// ------------------------------------------------------------------------ //
|
||||
// Coroutine suspend block: mark the end of a coroutine and return allocated
|
||||
|
@ -186,7 +187,7 @@ static CoroMachinery setupCoroMachinery(FuncOp func) {
|
|||
Operation *terminator = block.getTerminator();
|
||||
if (auto yield = dyn_cast<YieldOp>(terminator)) {
|
||||
builder.setInsertionPointToEnd(&block);
|
||||
builder.create<BranchOp>(cleanupBlock);
|
||||
builder.create<cf::BranchOp>(cleanupBlock);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -227,7 +228,7 @@ static Block *setupSetErrorBlock(CoroMachinery &coro) {
|
|||
builder.create<RuntimeSetErrorOp>(retValue);
|
||||
|
||||
// Branch into the cleanup block.
|
||||
builder.create<BranchOp>(coro.cleanup);
|
||||
builder.create<cf::BranchOp>(coro.cleanup);
|
||||
|
||||
return coro.setError;
|
||||
}
|
||||
|
@ -305,7 +306,7 @@ outlineExecuteOp(SymbolTable &symbolTable, ExecuteOp execute) {
|
|||
// Async resume operation (execution will be resumed in a thread managed by
|
||||
// the async runtime).
|
||||
{
|
||||
BranchOp branch = cast<BranchOp>(coro.entry->getTerminator());
|
||||
cf::BranchOp branch = cast<cf::BranchOp>(coro.entry->getTerminator());
|
||||
builder.setInsertionPointToEnd(coro.entry);
|
||||
|
||||
// Save the coroutine state: async.coro.save
|
||||
|
@ -419,8 +420,8 @@ public:
|
|||
isError, builder.create<arith::ConstantOp>(
|
||||
loc, i1, builder.getIntegerAttr(i1, 1)));
|
||||
|
||||
builder.create<AssertOp>(notError,
|
||||
"Awaited async operand is in error state");
|
||||
builder.create<cf::AssertOp>(notError,
|
||||
"Awaited async operand is in error state");
|
||||
}
|
||||
|
||||
// Inside the coroutine we convert await operation into coroutine suspension
|
||||
|
@ -452,11 +453,11 @@ public:
|
|||
// Check if the awaited value is in the error state.
|
||||
builder.setInsertionPointToStart(resume);
|
||||
auto isError = builder.create<RuntimeIsErrorOp>(loc, i1, operand);
|
||||
builder.create<CondBranchOp>(isError,
|
||||
/*trueDest=*/setupSetErrorBlock(coro),
|
||||
/*trueArgs=*/ArrayRef<Value>(),
|
||||
/*falseDest=*/continuation,
|
||||
/*falseArgs=*/ArrayRef<Value>());
|
||||
builder.create<cf::CondBranchOp>(isError,
|
||||
/*trueDest=*/setupSetErrorBlock(coro),
|
||||
/*trueArgs=*/ArrayRef<Value>(),
|
||||
/*falseDest=*/continuation,
|
||||
/*falseArgs=*/ArrayRef<Value>());
|
||||
|
||||
// Make sure that replacement value will be constructed in the
|
||||
// continuation block.
|
||||
|
@ -560,18 +561,18 @@ private:
|
|||
};
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Convert std.assert operation to cond_br into `set_error` block.
|
||||
// Convert cf.assert operation to cf.cond_br into `set_error` block.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
class AssertOpLowering : public OpConversionPattern<AssertOp> {
|
||||
class AssertOpLowering : public OpConversionPattern<cf::AssertOp> {
|
||||
public:
|
||||
AssertOpLowering(MLIRContext *ctx,
|
||||
llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions)
|
||||
: OpConversionPattern<AssertOp>(ctx),
|
||||
: OpConversionPattern<cf::AssertOp>(ctx),
|
||||
outlinedFunctions(outlinedFunctions) {}
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(AssertOp op, OpAdaptor adaptor,
|
||||
matchAndRewrite(cf::AssertOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
// Check if assert operation is inside the async coroutine function.
|
||||
auto func = op->template getParentOfType<FuncOp>();
|
||||
|
@ -585,11 +586,11 @@ public:
|
|||
|
||||
Block *cont = rewriter.splitBlock(op->getBlock(), Block::iterator(op));
|
||||
rewriter.setInsertionPointToEnd(cont->getPrevNode());
|
||||
rewriter.create<CondBranchOp>(loc, adaptor.getArg(),
|
||||
/*trueDest=*/cont,
|
||||
/*trueArgs=*/ArrayRef<Value>(),
|
||||
/*falseDest=*/setupSetErrorBlock(coro),
|
||||
/*falseArgs=*/ArrayRef<Value>());
|
||||
rewriter.create<cf::CondBranchOp>(loc, adaptor.getArg(),
|
||||
/*trueDest=*/cont,
|
||||
/*trueArgs=*/ArrayRef<Value>(),
|
||||
/*falseDest=*/setupSetErrorBlock(coro),
|
||||
/*falseArgs=*/ArrayRef<Value>());
|
||||
rewriter.eraseOp(op);
|
||||
|
||||
return success();
|
||||
|
@ -765,7 +766,7 @@ void AsyncToAsyncRuntimePass::runOnOperation() {
|
|||
// and we have to make sure that structured control flow operations with async
|
||||
// operations in nested regions will be converted to branch-based control flow
|
||||
// before we add the coroutine basic blocks.
|
||||
populateLoopToStdConversionPatterns(asyncPatterns);
|
||||
populateSCFToControlFlowConversionPatterns(asyncPatterns);
|
||||
|
||||
// Async lowering does not use type converter because it must preserve all
|
||||
// types for async.runtime operations.
|
||||
|
@ -792,14 +793,15 @@ void AsyncToAsyncRuntimePass::runOnOperation() {
|
|||
});
|
||||
return !walkResult.wasInterrupted();
|
||||
});
|
||||
runtimeTarget.addLegalOp<AssertOp, arith::XOrIOp, arith::ConstantOp,
|
||||
ConstantOp, BranchOp, CondBranchOp>();
|
||||
runtimeTarget.addLegalOp<cf::AssertOp, arith::XOrIOp, arith::ConstantOp,
|
||||
ConstantOp, cf::BranchOp, cf::CondBranchOp>();
|
||||
|
||||
// Assertions must be converted to runtime errors inside async functions.
|
||||
runtimeTarget.addDynamicallyLegalOp<AssertOp>([&](AssertOp op) -> bool {
|
||||
auto func = op->getParentOfType<FuncOp>();
|
||||
return outlinedFunctions.find(func) == outlinedFunctions.end();
|
||||
});
|
||||
runtimeTarget.addDynamicallyLegalOp<cf::AssertOp>(
|
||||
[&](cf::AssertOp op) -> bool {
|
||||
auto func = op->getParentOfType<FuncOp>();
|
||||
return outlinedFunctions.find(func) == outlinedFunctions.end();
|
||||
});
|
||||
|
||||
if (eliminateBlockingAwaitOps)
|
||||
runtimeTarget.addDynamicallyLegalOp<RuntimeAwaitOp>(
|
||||
|
|
|
@ -17,7 +17,7 @@ add_mlir_dialect_library(MLIRAsyncTransforms
|
|||
MLIRIR
|
||||
MLIRPass
|
||||
MLIRSCF
|
||||
MLIRSCFToStandard
|
||||
MLIRSCFToControlFlow
|
||||
MLIRStandard
|
||||
MLIRTransforms
|
||||
MLIRTransformUtils
|
||||
|
|
|
@ -18,12 +18,12 @@
|
|||
// (using the BufferViewFlowAnalysis class). Consider the following example:
|
||||
//
|
||||
// ^bb0(%arg0):
|
||||
// cond_br %cond, ^bb1, ^bb2
|
||||
// cf.cond_br %cond, ^bb1, ^bb2
|
||||
// ^bb1:
|
||||
// br ^exit(%arg0)
|
||||
// cf.br ^exit(%arg0)
|
||||
// ^bb2:
|
||||
// %new_value = ...
|
||||
// br ^exit(%new_value)
|
||||
// cf.br ^exit(%new_value)
|
||||
// ^exit(%arg1):
|
||||
// return %arg1;
|
||||
//
|
||||
|
|
|
@ -6,6 +6,7 @@ add_subdirectory(Async)
|
|||
add_subdirectory(AMX)
|
||||
add_subdirectory(Bufferization)
|
||||
add_subdirectory(Complex)
|
||||
add_subdirectory(ControlFlow)
|
||||
add_subdirectory(DLTI)
|
||||
add_subdirectory(EmitC)
|
||||
add_subdirectory(GPU)
|
||||
|
|
|
@ -0,0 +1 @@
|
|||
add_subdirectory(IR)
|
|
@ -0,0 +1,15 @@
|
|||
add_mlir_dialect_library(MLIRControlFlow
|
||||
ControlFlowOps.cpp
|
||||
|
||||
ADDITIONAL_HEADER_DIRS
|
||||
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/ControlFlow/IR
|
||||
|
||||
DEPENDS
|
||||
MLIRControlFlowOpsIncGen
|
||||
|
||||
LINK_LIBS PUBLIC
|
||||
MLIRArithmetic
|
||||
MLIRControlFlowInterfaces
|
||||
MLIRIR
|
||||
MLIRSideEffectInterfaces
|
||||
)
|
|
@ -0,0 +1,891 @@
|
|||
//===- ControlFlowOps.cpp - ControlFlow Operations ------------------------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
|
||||
|
||||
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
|
||||
#include "mlir/Dialect/CommonFolders.h"
|
||||
#include "mlir/IR/AffineExpr.h"
|
||||
#include "mlir/IR/AffineMap.h"
|
||||
#include "mlir/IR/BlockAndValueMapping.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/BuiltinOps.h"
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
#include "mlir/IR/Matchers.h"
|
||||
#include "mlir/IR/OpImplementation.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/IR/TypeUtilities.h"
|
||||
#include "mlir/IR/Value.h"
|
||||
#include "mlir/Support/MathExtras.h"
|
||||
#include "mlir/Transforms/InliningUtils.h"
|
||||
#include "llvm/ADT/APFloat.h"
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "llvm/ADT/StringSwitch.h"
|
||||
#include "llvm/Support/FormatVariadic.h"
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
#include <numeric>
|
||||
|
||||
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOpsDialect.cpp.inc"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::cf;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ControlFlowDialect Interfaces
|
||||
//===----------------------------------------------------------------------===//
|
||||
namespace {
|
||||
/// This class defines the interface for handling inlining with control flow
|
||||
/// operations.
|
||||
struct ControlFlowInlinerInterface : public DialectInlinerInterface {
|
||||
using DialectInlinerInterface::DialectInlinerInterface;
|
||||
~ControlFlowInlinerInterface() override = default;
|
||||
|
||||
/// All control flow operations can be inlined.
|
||||
bool isLegalToInline(Operation *call, Operation *callable,
|
||||
bool wouldBeCloned) const final {
|
||||
return true;
|
||||
}
|
||||
bool isLegalToInline(Operation *, Region *, bool,
|
||||
BlockAndValueMapping &) const final {
|
||||
return true;
|
||||
}
|
||||
|
||||
/// ControlFlow terminator operations don't really need any special handing.
|
||||
void handleTerminator(Operation *op, Block *newDest) const final {}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ControlFlowDialect
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// Registers the dialect's operations (the list comes from the generated
/// ControlFlowOps.cpp.inc under GET_OP_LIST) followed by its inliner
/// interface.
void ControlFlowDialect::initialize() {
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.cpp.inc"
      >();
  addInterfaces<ControlFlowInlinerInterface>();
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// AssertOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// Removes an assertion whose argument matches `m_One()`, i.e. a constant
/// true. Any other assertion is left untouched and the pattern fails.
LogicalResult AssertOp::canonicalize(AssertOp op, PatternRewriter &rewriter) {
  if (!matchPattern(op.getArg(), m_One()))
    return failure();

  rewriter.eraseOp(op);
  return success();
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// BranchOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// Given a successor, try to collapse it to a new destination if it only
|
||||
/// contains a passthrough unconditional branch. If the successor is
|
||||
/// collapsable, `successor` and `successorOperands` are updated to reference
|
||||
/// the new destination and values. `argStorage` is used as storage if operands
|
||||
/// to the collapsed successor need to be remapped. It must outlive uses of
|
||||
/// successorOperands.
|
||||
static LogicalResult collapseBranch(Block *&successor,
|
||||
ValueRange &successorOperands,
|
||||
SmallVectorImpl<Value> &argStorage) {
|
||||
// Check that the successor only contains a unconditional branch.
|
||||
if (std::next(successor->begin()) != successor->end())
|
||||
return failure();
|
||||
// Check that the terminator is an unconditional branch.
|
||||
BranchOp successorBranch = dyn_cast<BranchOp>(successor->getTerminator());
|
||||
if (!successorBranch)
|
||||
return failure();
|
||||
// Check that the arguments are only used within the terminator.
|
||||
for (BlockArgument arg : successor->getArguments()) {
|
||||
for (Operation *user : arg.getUsers())
|
||||
if (user != successorBranch)
|
||||
return failure();
|
||||
}
|
||||
// Don't try to collapse branches to infinite loops.
|
||||
Block *successorDest = successorBranch.getDest();
|
||||
if (successorDest == successor)
|
||||
return failure();
|
||||
|
||||
// Update the operands to the successor. If the branch parent has no
|
||||
// arguments, we can use the branch operands directly.
|
||||
OperandRange operands = successorBranch.getOperands();
|
||||
if (successor->args_empty()) {
|
||||
successor = successorDest;
|
||||
successorOperands = operands;
|
||||
return success();
|
||||
}
|
||||
|
||||
// Otherwise, we need to remap any argument operands.
|
||||
for (Value operand : operands) {
|
||||
BlockArgument argOperand = operand.dyn_cast<BlockArgument>();
|
||||
if (argOperand && argOperand.getOwner() == successor)
|
||||
argStorage.push_back(successorOperands[argOperand.getArgNumber()]);
|
||||
else
|
||||
argStorage.push_back(operand);
|
||||
}
|
||||
successor = successorDest;
|
||||
successorOperands = argStorage;
|
||||
return success();
|
||||
}
|
||||
|
||||
/// Simplify a branch to a block that has a single predecessor. This effectively
|
||||
/// merges the two blocks.
|
||||
static LogicalResult
|
||||
simplifyBrToBlockWithSinglePred(BranchOp op, PatternRewriter &rewriter) {
|
||||
// Check that the successor block has a single predecessor.
|
||||
Block *succ = op.getDest();
|
||||
Block *opParent = op->getBlock();
|
||||
if (succ == opParent || !llvm::hasSingleElement(succ->getPredecessors()))
|
||||
return failure();
|
||||
|
||||
// Merge the successor into the current block and erase the branch.
|
||||
rewriter.mergeBlocks(succ, opParent, op.getOperands());
|
||||
rewriter.eraseOp(op);
|
||||
return success();
|
||||
}
|
||||
|
||||
/// br ^bb1
|
||||
/// ^bb1
|
||||
/// br ^bbN(...)
|
||||
///
|
||||
/// -> br ^bbN(...)
|
||||
///
|
||||
static LogicalResult simplifyPassThroughBr(BranchOp op,
|
||||
PatternRewriter &rewriter) {
|
||||
Block *dest = op.getDest();
|
||||
ValueRange destOperands = op.getOperands();
|
||||
SmallVector<Value, 4> destOperandStorage;
|
||||
|
||||
// Try to collapse the successor if it points somewhere other than this
|
||||
// block.
|
||||
if (dest == op->getBlock() ||
|
||||
failed(collapseBranch(dest, destOperands, destOperandStorage)))
|
||||
return failure();
|
||||
|
||||
// Create a new branch with the collapsed successor.
|
||||
rewriter.replaceOpWithNewOp<BranchOp>(op, dest, destOperands);
|
||||
return success();
|
||||
}
|
||||
|
||||
/// Tries the single-predecessor merge first and, only if that did not apply,
/// the passthrough-branch simplification.
LogicalResult BranchOp::canonicalize(BranchOp op, PatternRewriter &rewriter) {
  if (succeeded(simplifyBrToBlockWithSinglePred(op, rewriter)))
    return success();
  return success(succeeded(simplifyPassThroughBr(op, rewriter)));
}
|
||||
|
||||
void BranchOp::setDest(Block *block) { return setSuccessor(block); }
|
||||
|
||||
void BranchOp::eraseOperand(unsigned index) { (*this)->eraseOperand(index); }
|
||||
|
||||
Optional<MutableOperandRange>
|
||||
BranchOp::getMutableSuccessorOperands(unsigned index) {
|
||||
assert(index == 0 && "invalid successor index");
|
||||
return getDestOperandsMutable();
|
||||
}
|
||||
|
||||
Block *BranchOp::getSuccessorForOperands(ArrayRef<Attribute>) {
|
||||
return getDest();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// CondBranchOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
namespace {
|
||||
/// cf.cond_br true, ^bb1, ^bb2
|
||||
/// -> br ^bb1
|
||||
/// cf.cond_br false, ^bb1, ^bb2
|
||||
/// -> br ^bb2
|
||||
///
|
||||
struct SimplifyConstCondBranchPred : public OpRewritePattern<CondBranchOp> {
|
||||
using OpRewritePattern<CondBranchOp>::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(CondBranchOp condbr,
|
||||
PatternRewriter &rewriter) const override {
|
||||
if (matchPattern(condbr.getCondition(), m_NonZero())) {
|
||||
// True branch taken.
|
||||
rewriter.replaceOpWithNewOp<BranchOp>(condbr, condbr.getTrueDest(),
|
||||
condbr.getTrueOperands());
|
||||
return success();
|
||||
}
|
||||
if (matchPattern(condbr.getCondition(), m_Zero())) {
|
||||
// False branch taken.
|
||||
rewriter.replaceOpWithNewOp<BranchOp>(condbr, condbr.getFalseDest(),
|
||||
condbr.getFalseOperands());
|
||||
return success();
|
||||
}
|
||||
return failure();
|
||||
}
|
||||
};
|
||||
|
||||
/// cf.cond_br %cond, ^bb1, ^bb2
|
||||
/// ^bb1
|
||||
/// br ^bbN(...)
|
||||
/// ^bb2
|
||||
/// br ^bbK(...)
|
||||
///
|
||||
/// -> cf.cond_br %cond, ^bbN(...), ^bbK(...)
|
||||
///
|
||||
struct SimplifyPassThroughCondBranch : public OpRewritePattern<CondBranchOp> {
|
||||
using OpRewritePattern<CondBranchOp>::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(CondBranchOp condbr,
|
||||
PatternRewriter &rewriter) const override {
|
||||
Block *trueDest = condbr.getTrueDest(), *falseDest = condbr.getFalseDest();
|
||||
ValueRange trueDestOperands = condbr.getTrueOperands();
|
||||
ValueRange falseDestOperands = condbr.getFalseOperands();
|
||||
SmallVector<Value, 4> trueDestOperandStorage, falseDestOperandStorage;
|
||||
|
||||
// Try to collapse one of the current successors.
|
||||
LogicalResult collapsedTrue =
|
||||
collapseBranch(trueDest, trueDestOperands, trueDestOperandStorage);
|
||||
LogicalResult collapsedFalse =
|
||||
collapseBranch(falseDest, falseDestOperands, falseDestOperandStorage);
|
||||
if (failed(collapsedTrue) && failed(collapsedFalse))
|
||||
return failure();
|
||||
|
||||
// Create a new branch with the collapsed successors.
|
||||
rewriter.replaceOpWithNewOp<CondBranchOp>(condbr, condbr.getCondition(),
|
||||
trueDest, trueDestOperands,
|
||||
falseDest, falseDestOperands);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
/// cf.cond_br %cond, ^bb1(A, ..., N), ^bb1(A, ..., N)
|
||||
/// -> br ^bb1(A, ..., N)
|
||||
///
|
||||
/// cf.cond_br %cond, ^bb1(A), ^bb1(B)
|
||||
/// -> %select = arith.select %cond, A, B
|
||||
/// br ^bb1(%select)
|
||||
///
|
||||
struct SimplifyCondBranchIdenticalSuccessors
|
||||
: public OpRewritePattern<CondBranchOp> {
|
||||
using OpRewritePattern<CondBranchOp>::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(CondBranchOp condbr,
|
||||
PatternRewriter &rewriter) const override {
|
||||
// Check that the true and false destinations are the same and have the same
|
||||
// operands.
|
||||
Block *trueDest = condbr.getTrueDest();
|
||||
if (trueDest != condbr.getFalseDest())
|
||||
return failure();
|
||||
|
||||
// If all of the operands match, no selects need to be generated.
|
||||
OperandRange trueOperands = condbr.getTrueOperands();
|
||||
OperandRange falseOperands = condbr.getFalseOperands();
|
||||
if (trueOperands == falseOperands) {
|
||||
rewriter.replaceOpWithNewOp<BranchOp>(condbr, trueDest, trueOperands);
|
||||
return success();
|
||||
}
|
||||
|
||||
// Otherwise, if the current block is the only predecessor insert selects
|
||||
// for any mismatched branch operands.
|
||||
if (trueDest->getUniquePredecessor() != condbr->getBlock())
|
||||
return failure();
|
||||
|
||||
// Generate a select for any operands that differ between the two.
|
||||
SmallVector<Value, 8> mergedOperands;
|
||||
mergedOperands.reserve(trueOperands.size());
|
||||
Value condition = condbr.getCondition();
|
||||
for (auto it : llvm::zip(trueOperands, falseOperands)) {
|
||||
if (std::get<0>(it) == std::get<1>(it))
|
||||
mergedOperands.push_back(std::get<0>(it));
|
||||
else
|
||||
mergedOperands.push_back(rewriter.create<arith::SelectOp>(
|
||||
condbr.getLoc(), condition, std::get<0>(it), std::get<1>(it)));
|
||||
}
|
||||
|
||||
rewriter.replaceOpWithNewOp<BranchOp>(condbr, trueDest, mergedOperands);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
/// ...
|
||||
/// cf.cond_br %cond, ^bb1(...), ^bb2(...)
|
||||
/// ...
|
||||
/// ^bb1: // has single predecessor
|
||||
/// ...
|
||||
/// cf.cond_br %cond, ^bb3(...), ^bb4(...)
|
||||
///
|
||||
/// ->
|
||||
///
|
||||
/// ...
|
||||
/// cf.cond_br %cond, ^bb1(...), ^bb2(...)
|
||||
/// ...
|
||||
/// ^bb1: // has single predecessor
|
||||
/// ...
|
||||
/// br ^bb3(...)
|
||||
///
|
||||
struct SimplifyCondBranchFromCondBranchOnSameCondition
|
||||
: public OpRewritePattern<CondBranchOp> {
|
||||
using OpRewritePattern<CondBranchOp>::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(CondBranchOp condbr,
|
||||
PatternRewriter &rewriter) const override {
|
||||
// Check that we have a single distinct predecessor.
|
||||
Block *currentBlock = condbr->getBlock();
|
||||
Block *predecessor = currentBlock->getSinglePredecessor();
|
||||
if (!predecessor)
|
||||
return failure();
|
||||
|
||||
// Check that the predecessor terminates with a conditional branch to this
|
||||
// block and that it branches on the same condition.
|
||||
auto predBranch = dyn_cast<CondBranchOp>(predecessor->getTerminator());
|
||||
if (!predBranch || condbr.getCondition() != predBranch.getCondition())
|
||||
return failure();
|
||||
|
||||
// Fold this branch to an unconditional branch.
|
||||
if (currentBlock == predBranch.getTrueDest())
|
||||
rewriter.replaceOpWithNewOp<BranchOp>(condbr, condbr.getTrueDest(),
|
||||
condbr.getTrueDestOperands());
|
||||
else
|
||||
rewriter.replaceOpWithNewOp<BranchOp>(condbr, condbr.getFalseDest(),
|
||||
condbr.getFalseDestOperands());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
/// cf.cond_br %arg0, ^trueB, ^falseB
|
||||
///
|
||||
/// ^trueB:
|
||||
/// "test.consumer1"(%arg0) : (i1) -> ()
|
||||
/// ...
|
||||
///
|
||||
/// ^falseB:
|
||||
/// "test.consumer2"(%arg0) : (i1) -> ()
|
||||
/// ...
|
||||
///
|
||||
/// ->
|
||||
///
|
||||
/// cf.cond_br %arg0, ^trueB, ^falseB
|
||||
/// ^trueB:
|
||||
/// "test.consumer1"(%true) : (i1) -> ()
|
||||
/// ...
|
||||
///
|
||||
/// ^falseB:
|
||||
/// "test.consumer2"(%false) : (i1) -> ()
|
||||
/// ...
|
||||
struct CondBranchTruthPropagation : public OpRewritePattern<CondBranchOp> {
|
||||
using OpRewritePattern<CondBranchOp>::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(CondBranchOp condbr,
|
||||
PatternRewriter &rewriter) const override {
|
||||
// Check that we have a single distinct predecessor.
|
||||
bool replaced = false;
|
||||
Type ty = rewriter.getI1Type();
|
||||
|
||||
// These variables serve to prevent creating duplicate constants
|
||||
// and hold constant true or false values.
|
||||
Value constantTrue = nullptr;
|
||||
Value constantFalse = nullptr;
|
||||
|
||||
// TODO These checks can be expanded to encompas any use with only
|
||||
// either the true of false edge as a predecessor. For now, we fall
|
||||
// back to checking the single predecessor is given by the true/fasle
|
||||
// destination, thereby ensuring that only that edge can reach the
|
||||
// op.
|
||||
if (condbr.getTrueDest()->getSinglePredecessor()) {
|
||||
for (OpOperand &use :
|
||||
llvm::make_early_inc_range(condbr.getCondition().getUses())) {
|
||||
if (use.getOwner()->getBlock() == condbr.getTrueDest()) {
|
||||
replaced = true;
|
||||
|
||||
if (!constantTrue)
|
||||
constantTrue = rewriter.create<arith::ConstantOp>(
|
||||
condbr.getLoc(), ty, rewriter.getBoolAttr(true));
|
||||
|
||||
rewriter.updateRootInPlace(use.getOwner(),
|
||||
[&] { use.set(constantTrue); });
|
||||
}
|
||||
}
|
||||
}
|
||||
if (condbr.getFalseDest()->getSinglePredecessor()) {
|
||||
for (OpOperand &use :
|
||||
llvm::make_early_inc_range(condbr.getCondition().getUses())) {
|
||||
if (use.getOwner()->getBlock() == condbr.getFalseDest()) {
|
||||
replaced = true;
|
||||
|
||||
if (!constantFalse)
|
||||
constantFalse = rewriter.create<arith::ConstantOp>(
|
||||
condbr.getLoc(), ty, rewriter.getBoolAttr(false));
|
||||
|
||||
rewriter.updateRootInPlace(use.getOwner(),
|
||||
[&] { use.set(constantFalse); });
|
||||
}
|
||||
}
|
||||
}
|
||||
return success(replaced);
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
/// Adds the cond_br canonicalization patterns defined in this file to
/// `results`, in the order listed.
void CondBranchOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                               MLIRContext *context) {
  results.add<SimplifyConstCondBranchPred>(context);
  results.add<SimplifyPassThroughCondBranch>(context);
  results.add<SimplifyCondBranchIdenticalSuccessors>(context);
  results.add<SimplifyCondBranchFromCondBranchOnSameCondition>(context);
  results.add<CondBranchTruthPropagation>(context);
}
|
||||
|
||||
/// Returns the mutable operands forwarded to successor `index`: the true
/// destination's operands when `index` equals `trueIndex`, the false
/// destination's otherwise.
Optional<MutableOperandRange>
CondBranchOp::getMutableSuccessorOperands(unsigned index) {
  assert(index < getNumSuccessors() && "invalid successor index");
  if (index == trueIndex)
    return getTrueDestOperandsMutable();
  return getFalseDestOperandsMutable();
}
|
||||
|
||||
/// Picks the successor from the first entry of `operands` when it is an
/// integer attribute: the true destination if its value is one, the false
/// destination otherwise. Returns null when that entry is not an integer
/// attribute.
Block *CondBranchOp::getSuccessorForOperands(ArrayRef<Attribute> operands) {
  auto condAttr = operands.front().dyn_cast_or_null<IntegerAttr>();
  if (!condAttr)
    return nullptr;
  if (condAttr.getValue().isOneValue())
    return getTrueDest();
  return getFalseDest();
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// SwitchOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
void SwitchOp::build(OpBuilder &builder, OperationState &result, Value value,
|
||||
Block *defaultDestination, ValueRange defaultOperands,
|
||||
DenseIntElementsAttr caseValues,
|
||||
BlockRange caseDestinations,
|
||||
ArrayRef<ValueRange> caseOperands) {
|
||||
build(builder, result, value, defaultOperands, caseOperands, caseValues,
|
||||
defaultDestination, caseDestinations);
|
||||
}
|
||||
|
||||
void SwitchOp::build(OpBuilder &builder, OperationState &result, Value value,
|
||||
Block *defaultDestination, ValueRange defaultOperands,
|
||||
ArrayRef<APInt> caseValues, BlockRange caseDestinations,
|
||||
ArrayRef<ValueRange> caseOperands) {
|
||||
DenseIntElementsAttr caseValuesAttr;
|
||||
if (!caseValues.empty()) {
|
||||
ShapedType caseValueType = VectorType::get(
|
||||
static_cast<int64_t>(caseValues.size()), value.getType());
|
||||
caseValuesAttr = DenseIntElementsAttr::get(caseValueType, caseValues);
|
||||
}
|
||||
build(builder, result, value, defaultDestination, defaultOperands,
|
||||
caseValuesAttr, caseDestinations, caseOperands);
|
||||
}
|
||||
|
||||
/// <cases> ::= `default` `:` bb-id (`(` ssa-use-and-type-list `)`)?
|
||||
/// ( `,` integer `:` bb-id (`(` ssa-use-and-type-list `)`)? )*
|
||||
static ParseResult parseSwitchOpCases(
|
||||
OpAsmParser &parser, Type &flagType, Block *&defaultDestination,
|
||||
SmallVectorImpl<OpAsmParser::OperandType> &defaultOperands,
|
||||
SmallVectorImpl<Type> &defaultOperandTypes,
|
||||
DenseIntElementsAttr &caseValues,
|
||||
SmallVectorImpl<Block *> &caseDestinations,
|
||||
SmallVectorImpl<SmallVector<OpAsmParser::OperandType>> &caseOperands,
|
||||
SmallVectorImpl<SmallVector<Type>> &caseOperandTypes) {
|
||||
if (parser.parseKeyword("default") || parser.parseColon() ||
|
||||
parser.parseSuccessor(defaultDestination))
|
||||
return failure();
|
||||
if (succeeded(parser.parseOptionalLParen())) {
|
||||
if (parser.parseRegionArgumentList(defaultOperands) ||
|
||||
parser.parseColonTypeList(defaultOperandTypes) || parser.parseRParen())
|
||||
return failure();
|
||||
}
|
||||
|
||||
SmallVector<APInt> values;
|
||||
unsigned bitWidth = flagType.getIntOrFloatBitWidth();
|
||||
while (succeeded(parser.parseOptionalComma())) {
|
||||
int64_t value = 0;
|
||||
if (failed(parser.parseInteger(value)))
|
||||
return failure();
|
||||
values.push_back(APInt(bitWidth, value));
|
||||
|
||||
Block *destination;
|
||||
SmallVector<OpAsmParser::OperandType> operands;
|
||||
SmallVector<Type> operandTypes;
|
||||
if (failed(parser.parseColon()) ||
|
||||
failed(parser.parseSuccessor(destination)))
|
||||
return failure();
|
||||
if (succeeded(parser.parseOptionalLParen())) {
|
||||
if (failed(parser.parseRegionArgumentList(operands)) ||
|
||||
failed(parser.parseColonTypeList(operandTypes)) ||
|
||||
failed(parser.parseRParen()))
|
||||
return failure();
|
||||
}
|
||||
caseDestinations.push_back(destination);
|
||||
caseOperands.emplace_back(operands);
|
||||
caseOperandTypes.emplace_back(operandTypes);
|
||||
}
|
||||
|
||||
if (!values.empty()) {
|
||||
ShapedType caseValueType =
|
||||
VectorType::get(static_cast<int64_t>(values.size()), flagType);
|
||||
caseValues = DenseIntElementsAttr::get(caseValueType, values);
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
||||
/// Prints the case list of a switch: the default destination with its
/// operands, then one `value: destination` entry per case value, each on its
/// own line. Nothing beyond the default entry is printed when `caseValues`
/// is null.
static void printSwitchOpCases(
    OpAsmPrinter &p, SwitchOp op, Type flagType, Block *defaultDestination,
    OperandRange defaultOperands, TypeRange defaultOperandTypes,
    DenseIntElementsAttr caseValues, SuccessorRange caseDestinations,
    OperandRangeRange caseOperands, const TypeRangeRange &caseOperandTypes) {
  p << " default: ";
  p.printSuccessorAndUseList(defaultDestination, defaultOperands);

  if (!caseValues)
    return;

  size_t caseIdx = 0;
  for (APInt caseValue : caseValues.getValues<APInt>()) {
    p << ',';
    p.printNewline();
    p << " ";
    p << caseValue.getLimitedValue();
    p << ": ";
    p.printSuccessorAndUseList(caseDestinations[caseIdx],
                               caseOperands[caseIdx]);
    ++caseIdx;
  }
  p.printNewline();
}
|
||||
|
||||
/// Verifies the consistency of the case list:
///  * a switch with neither case values nor case destinations is valid;
///  * case destinations without case values are rejected;
///  * the element type of the case values must equal the type of the flag;
///  * there must be exactly one case destination per case value.
LogicalResult SwitchOp::verify() {
  auto caseValues = getCaseValues();
  auto caseDestinations = getCaseDestinations();

  if (!caseValues) {
    if (caseDestinations.empty())
      return success();
    // Destinations without values: report the mismatch instead of
    // dereferencing the empty optional below.
    return emitOpError() << "number of case values (" << 0
                         << ") should match number of "
                            "case destinations ("
                         << caseDestinations.size() << ")";
  }

  Type flagType = getFlag().getType();
  Type caseValueType = caseValues->getType().getElementType();
  if (caseValueType != flagType)
    return emitOpError() << "'flag' type (" << flagType
                         << ") should match case value type (" << caseValueType
                         << ")";

  if (caseValues->size() != static_cast<int64_t>(caseDestinations.size()))
    return emitOpError() << "number of case values (" << caseValues->size()
                         << ") should match number of "
                            "case destinations ("
                         << caseDestinations.size() << ")";
  return success();
}
|
||||
|
||||
/// Returns the mutable operands forwarded to successor `index`. Successor 0
/// is the default destination; successor `i > 0` is case `i - 1`.
Optional<MutableOperandRange>
SwitchOp::getMutableSuccessorOperands(unsigned index) {
  assert(index < getNumSuccessors() && "invalid successor index");
  if (index == 0)
    return getDefaultOperandsMutable();
  return getCaseOperandsMutable(index - 1);
}
|
||||
|
||||
/// Determines the successor taken for the given operand attributes. Without
/// case values the default destination is always returned. Otherwise, when
/// the first entry of `operands` is an integer attribute, the destination of
/// the first case with that value is returned, or the default destination if
/// no case has it. Returns null when that entry is not an integer attribute.
Block *SwitchOp::getSuccessorForOperands(ArrayRef<Attribute> operands) {
  Optional<DenseIntElementsAttr> caseValues = getCaseValues();
  if (!caseValues)
    return getDefaultDestination();

  SuccessorRange caseDests = getCaseDestinations();
  auto flagAttr = operands.front().dyn_cast_or_null<IntegerAttr>();
  if (!flagAttr)
    return nullptr;

  size_t caseIdx = 0;
  for (APInt candidate : caseValues->getValues<APInt>()) {
    if (candidate == flagAttr.getValue())
      return caseDests[caseIdx];
    ++caseIdx;
  }
  return getDefaultDestination();
}
|
||||
|
||||
/// switch %flag : i32, [
|
||||
/// default: ^bb1
|
||||
/// ]
|
||||
/// -> br ^bb1
|
||||
static LogicalResult simplifySwitchWithOnlyDefault(SwitchOp op,
|
||||
PatternRewriter &rewriter) {
|
||||
if (!op.getCaseDestinations().empty())
|
||||
return failure();
|
||||
|
||||
rewriter.replaceOpWithNewOp<BranchOp>(op, op.getDefaultDestination(),
|
||||
op.getDefaultOperands());
|
||||
return success();
|
||||
}
|
||||
|
||||
/// switch %flag : i32, [
|
||||
/// default: ^bb1,
|
||||
/// 42: ^bb1,
|
||||
/// 43: ^bb2
|
||||
/// ]
|
||||
/// ->
|
||||
/// switch %flag : i32, [
|
||||
/// default: ^bb1,
|
||||
/// 43: ^bb2
|
||||
/// ]
|
||||
static LogicalResult
|
||||
dropSwitchCasesThatMatchDefault(SwitchOp op, PatternRewriter &rewriter) {
|
||||
SmallVector<Block *> newCaseDestinations;
|
||||
SmallVector<ValueRange> newCaseOperands;
|
||||
SmallVector<APInt> newCaseValues;
|
||||
bool requiresChange = false;
|
||||
auto caseValues = op.getCaseValues();
|
||||
auto caseDests = op.getCaseDestinations();
|
||||
|
||||
for (const auto &it : llvm::enumerate(caseValues->getValues<APInt>())) {
|
||||
if (caseDests[it.index()] == op.getDefaultDestination() &&
|
||||
op.getCaseOperands(it.index()) == op.getDefaultOperands()) {
|
||||
requiresChange = true;
|
||||
continue;
|
||||
}
|
||||
newCaseDestinations.push_back(caseDests[it.index()]);
|
||||
newCaseOperands.push_back(op.getCaseOperands(it.index()));
|
||||
newCaseValues.push_back(it.value());
|
||||
}
|
||||
|
||||
if (!requiresChange)
|
||||
return failure();
|
||||
|
||||
rewriter.replaceOpWithNewOp<SwitchOp>(
|
||||
op, op.getFlag(), op.getDefaultDestination(), op.getDefaultOperands(),
|
||||
newCaseValues, newCaseDestinations, newCaseOperands);
|
||||
return success();
|
||||
}
|
||||
|
||||
/// Helper for folding a switch with a constant value.
|
||||
/// switch %c_42 : i32, [
|
||||
/// default: ^bb1 ,
|
||||
/// 42: ^bb2,
|
||||
/// 43: ^bb3
|
||||
/// ]
|
||||
/// -> br ^bb2
|
||||
static void foldSwitch(SwitchOp op, PatternRewriter &rewriter,
|
||||
const APInt &caseValue) {
|
||||
auto caseValues = op.getCaseValues();
|
||||
for (const auto &it : llvm::enumerate(caseValues->getValues<APInt>())) {
|
||||
if (it.value() == caseValue) {
|
||||
rewriter.replaceOpWithNewOp<BranchOp>(
|
||||
op, op.getCaseDestinations()[it.index()],
|
||||
op.getCaseOperands(it.index()));
|
||||
return;
|
||||
}
|
||||
}
|
||||
rewriter.replaceOpWithNewOp<BranchOp>(op, op.getDefaultDestination(),
|
||||
op.getDefaultOperands());
|
||||
}
|
||||
|
||||
/// switch %c_42 : i32, [
|
||||
/// default: ^bb1,
|
||||
/// 42: ^bb2,
|
||||
/// 43: ^bb3
|
||||
/// ]
|
||||
/// -> br ^bb2
|
||||
static LogicalResult simplifyConstSwitchValue(SwitchOp op,
|
||||
PatternRewriter &rewriter) {
|
||||
APInt caseValue;
|
||||
if (!matchPattern(op.getFlag(), m_ConstantInt(&caseValue)))
|
||||
return failure();
|
||||
|
||||
foldSwitch(op, rewriter, caseValue);
|
||||
return success();
|
||||
}
|
||||
|
||||
/// switch %c_42 : i32, [
|
||||
/// default: ^bb1,
|
||||
/// 42: ^bb2,
|
||||
/// ]
|
||||
/// ^bb2:
|
||||
/// br ^bb3
|
||||
/// ->
|
||||
/// switch %c_42 : i32, [
|
||||
/// default: ^bb1,
|
||||
/// 42: ^bb3,
|
||||
/// ]
|
||||
static LogicalResult simplifyPassThroughSwitch(SwitchOp op,
|
||||
PatternRewriter &rewriter) {
|
||||
SmallVector<Block *> newCaseDests;
|
||||
SmallVector<ValueRange> newCaseOperands;
|
||||
SmallVector<SmallVector<Value>> argStorage;
|
||||
auto caseValues = op.getCaseValues();
|
||||
auto caseDests = op.getCaseDestinations();
|
||||
bool requiresChange = false;
|
||||
for (int64_t i = 0, size = caseValues->size(); i < size; ++i) {
|
||||
Block *caseDest = caseDests[i];
|
||||
ValueRange caseOperands = op.getCaseOperands(i);
|
||||
argStorage.emplace_back();
|
||||
if (succeeded(collapseBranch(caseDest, caseOperands, argStorage.back())))
|
||||
requiresChange = true;
|
||||
|
||||
newCaseDests.push_back(caseDest);
|
||||
newCaseOperands.push_back(caseOperands);
|
||||
}
|
||||
|
||||
Block *defaultDest = op.getDefaultDestination();
|
||||
ValueRange defaultOperands = op.getDefaultOperands();
|
||||
argStorage.emplace_back();
|
||||
|
||||
if (succeeded(
|
||||
collapseBranch(defaultDest, defaultOperands, argStorage.back())))
|
||||
requiresChange = true;
|
||||
|
||||
if (!requiresChange)
|
||||
return failure();
|
||||
|
||||
rewriter.replaceOpWithNewOp<SwitchOp>(op, op.getFlag(), defaultDest,
|
||||
defaultOperands, caseValues.getValue(),
|
||||
newCaseDests, newCaseOperands);
|
||||
return success();
|
||||
}
|
||||
|
||||
/// switch %flag : i32, [
|
||||
/// default: ^bb1,
|
||||
/// 42: ^bb2,
|
||||
/// ]
|
||||
/// ^bb2:
|
||||
/// switch %flag : i32, [
|
||||
/// default: ^bb3,
|
||||
/// 42: ^bb4
|
||||
/// ]
|
||||
/// ->
|
||||
/// switch %flag : i32, [
|
||||
/// default: ^bb1,
|
||||
/// 42: ^bb2,
|
||||
/// ]
|
||||
/// ^bb2:
|
||||
/// br ^bb4
|
||||
///
|
||||
/// and
|
||||
///
|
||||
/// switch %flag : i32, [
|
||||
/// default: ^bb1,
|
||||
/// 42: ^bb2,
|
||||
/// ]
|
||||
/// ^bb2:
|
||||
/// switch %flag : i32, [
|
||||
/// default: ^bb3,
|
||||
/// 43: ^bb4
|
||||
/// ]
|
||||
/// ->
|
||||
/// switch %flag : i32, [
|
||||
/// default: ^bb1,
|
||||
/// 42: ^bb2,
|
||||
/// ]
|
||||
/// ^bb2:
|
||||
/// br ^bb3
|
||||
static LogicalResult
|
||||
simplifySwitchFromSwitchOnSameCondition(SwitchOp op,
|
||||
PatternRewriter &rewriter) {
|
||||
// Check that we have a single distinct predecessor.
|
||||
Block *currentBlock = op->getBlock();
|
||||
Block *predecessor = currentBlock->getSinglePredecessor();
|
||||
if (!predecessor)
|
||||
return failure();
|
||||
|
||||
// Check that the predecessor terminates with a switch branch to this block
|
||||
// and that it branches on the same condition and that this branch isn't the
|
||||
// default destination.
|
||||
auto predSwitch = dyn_cast<SwitchOp>(predecessor->getTerminator());
|
||||
if (!predSwitch || op.getFlag() != predSwitch.getFlag() ||
|
||||
predSwitch.getDefaultDestination() == currentBlock)
|
||||
return failure();
|
||||
|
||||
// Fold this switch to an unconditional branch.
|
||||
SuccessorRange predDests = predSwitch.getCaseDestinations();
|
||||
auto it = llvm::find(predDests, currentBlock);
|
||||
if (it != predDests.end()) {
|
||||
Optional<DenseIntElementsAttr> predCaseValues = predSwitch.getCaseValues();
|
||||
foldSwitch(op, rewriter,
|
||||
predCaseValues->getValues<APInt>()[it - predDests.begin()]);
|
||||
} else {
|
||||
rewriter.replaceOpWithNewOp<BranchOp>(op, op.getDefaultDestination(),
|
||||
op.getDefaultOperands());
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
||||
/// switch %flag : i32, [
|
||||
/// default: ^bb1,
|
||||
/// 42: ^bb2
|
||||
/// ]
|
||||
/// ^bb1:
|
||||
/// switch %flag : i32, [
|
||||
/// default: ^bb3,
|
||||
/// 42: ^bb4,
|
||||
/// 43: ^bb5
|
||||
/// ]
|
||||
/// ->
|
||||
/// switch %flag : i32, [
|
||||
/// default: ^bb1,
|
||||
/// 42: ^bb2,
|
||||
/// ]
|
||||
/// ^bb1:
|
||||
/// switch %flag : i32, [
|
||||
/// default: ^bb3,
|
||||
/// 43: ^bb5
|
||||
/// ]
|
||||
static LogicalResult
|
||||
simplifySwitchFromDefaultSwitchOnSameCondition(SwitchOp op,
|
||||
PatternRewriter &rewriter) {
|
||||
// Check that we have a single distinct predecessor.
|
||||
Block *currentBlock = op->getBlock();
|
||||
Block *predecessor = currentBlock->getSinglePredecessor();
|
||||
if (!predecessor)
|
||||
return failure();
|
||||
|
||||
// Check that the predecessor terminates with a switch branch to this block
|
||||
// and that it branches on the same condition and that this branch is the
|
||||
// default destination.
|
||||
auto predSwitch = dyn_cast<SwitchOp>(predecessor->getTerminator());
|
||||
if (!predSwitch || op.getFlag() != predSwitch.getFlag() ||
|
||||
predSwitch.getDefaultDestination() != currentBlock)
|
||||
return failure();
|
||||
|
||||
// Delete case values that are not possible here.
|
||||
DenseSet<APInt> caseValuesToRemove;
|
||||
auto predDests = predSwitch.getCaseDestinations();
|
||||
auto predCaseValues = predSwitch.getCaseValues();
|
||||
for (int64_t i = 0, size = predCaseValues->size(); i < size; ++i)
|
||||
if (currentBlock != predDests[i])
|
||||
caseValuesToRemove.insert(predCaseValues->getValues<APInt>()[i]);
|
||||
|
||||
SmallVector<Block *> newCaseDestinations;
|
||||
SmallVector<ValueRange> newCaseOperands;
|
||||
SmallVector<APInt> newCaseValues;
|
||||
bool requiresChange = false;
|
||||
|
||||
auto caseValues = op.getCaseValues();
|
||||
auto caseDests = op.getCaseDestinations();
|
||||
for (const auto &it : llvm::enumerate(caseValues->getValues<APInt>())) {
|
||||
if (caseValuesToRemove.contains(it.value())) {
|
||||
requiresChange = true;
|
||||
continue;
|
||||
}
|
||||
newCaseDestinations.push_back(caseDests[it.index()]);
|
||||
newCaseOperands.push_back(op.getCaseOperands(it.index()));
|
||||
newCaseValues.push_back(it.value());
|
||||
}
|
||||
|
||||
if (!requiresChange)
|
||||
return failure();
|
||||
|
||||
rewriter.replaceOpWithNewOp<SwitchOp>(
|
||||
op, op.getFlag(), op.getDefaultDestination(), op.getDefaultOperands(),
|
||||
newCaseValues, newCaseDestinations, newCaseOperands);
|
||||
return success();
|
||||
}
|
||||
|
||||
/// Adds the switch canonicalization functions defined in this file to
/// `results`, in the order listed.
void SwitchOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                           MLIRContext *context) {
  results.add(&simplifySwitchWithOnlyDefault);
  results.add(&dropSwitchCasesThatMatchDefault);
  results.add(&simplifyConstSwitchValue);
  results.add(&simplifyPassThroughSwitch);
  results.add(&simplifySwitchFromSwitchOnSameCondition);
  results.add(&simplifySwitchFromDefaultSwitchOnSameCondition);
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// TableGen'd op method definitions
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#define GET_OP_CLASSES
|
||||
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.cpp.inc"
|
|
@ -12,10 +12,10 @@
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
|
||||
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
|
||||
#include "mlir/Dialect/GPU/GPUDialect.h"
|
||||
#include "mlir/Dialect/GPU/Passes.h"
|
||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||
#include "mlir/IR/BlockAndValueMapping.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
|
@ -44,14 +44,14 @@ struct GpuAllReduceRewriter {
|
|||
/// workgroup memory.
|
||||
///
|
||||
/// %subgroup_reduce = `createSubgroupReduce(%operand)`
|
||||
/// cond_br %is_first_lane, ^then1, ^continue1
|
||||
/// cf.cond_br %is_first_lane, ^then1, ^continue1
|
||||
/// ^then1:
|
||||
/// store %subgroup_reduce, %workgroup_buffer[%subgroup_id]
|
||||
/// br ^continue1
|
||||
/// cf.br ^continue1
|
||||
/// ^continue1:
|
||||
/// gpu.barrier
|
||||
/// %is_valid_subgroup = arith.cmpi "slt" %invocation_idx, %num_subgroups
|
||||
/// cond_br %is_valid_subgroup, ^then2, ^continue2
|
||||
/// cf.cond_br %is_valid_subgroup, ^then2, ^continue2
|
||||
/// ^then2:
|
||||
/// %partial_reduce = load %workgroup_buffer[%invocation_idx]
|
||||
/// %all_reduce = `createSubgroupReduce(%partial_reduce)`
|
||||
|
@ -194,7 +194,7 @@ private:
|
|||
|
||||
// Add branch before inserted body, into body.
|
||||
block = block->getNextNode();
|
||||
create<BranchOp>(block, ValueRange());
|
||||
create<cf::BranchOp>(block, ValueRange());
|
||||
|
||||
// Replace all gpu.yield ops with branch out of body.
|
||||
for (; block != split; block = block->getNextNode()) {
|
||||
|
@ -202,7 +202,7 @@ private:
|
|||
if (!isa<gpu::YieldOp>(terminator))
|
||||
continue;
|
||||
rewriter.setInsertionPointToEnd(block);
|
||||
rewriter.replaceOpWithNewOp<BranchOp>(
|
||||
rewriter.replaceOpWithNewOp<cf::BranchOp>(
|
||||
terminator, split, ValueRange(terminator->getOperand(0)));
|
||||
}
|
||||
|
||||
|
@ -285,17 +285,17 @@ private:
|
|||
Block *continueBlock = rewriter.splitBlock(elseBlock, elseBlock->begin());
|
||||
|
||||
rewriter.setInsertionPointToEnd(currentBlock);
|
||||
create<CondBranchOp>(condition, thenBlock,
|
||||
/*trueOperands=*/ArrayRef<Value>(), elseBlock,
|
||||
/*falseOperands=*/ArrayRef<Value>());
|
||||
create<cf::CondBranchOp>(condition, thenBlock,
|
||||
/*trueOperands=*/ArrayRef<Value>(), elseBlock,
|
||||
/*falseOperands=*/ArrayRef<Value>());
|
||||
|
||||
rewriter.setInsertionPointToStart(thenBlock);
|
||||
auto thenOperands = thenOpsFactory();
|
||||
create<BranchOp>(continueBlock, thenOperands);
|
||||
create<cf::BranchOp>(continueBlock, thenOperands);
|
||||
|
||||
rewriter.setInsertionPointToStart(elseBlock);
|
||||
auto elseOperands = elseOpsFactory();
|
||||
create<BranchOp>(continueBlock, elseOperands);
|
||||
create<cf::BranchOp>(continueBlock, elseOperands);
|
||||
|
||||
assert(thenOperands.size() == elseOperands.size());
|
||||
rewriter.setInsertionPointToStart(continueBlock);
|
||||
|
|
|
@ -12,6 +12,7 @@
|
|||
|
||||
#include "PassDetail.h"
|
||||
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
|
||||
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
|
||||
#include "mlir/Dialect/DLTI/DLTI.h"
|
||||
#include "mlir/Dialect/GPU/GPUDialect.h"
|
||||
#include "mlir/Dialect/GPU/Passes.h"
|
||||
|
@ -186,7 +187,7 @@ static gpu::GPUFuncOp outlineKernelFuncImpl(gpu::LaunchOp launchOp,
|
|||
Block &launchOpEntry = launchOpBody.front();
|
||||
Block *clonedLaunchOpEntry = map.lookup(&launchOpEntry);
|
||||
builder.setInsertionPointToEnd(&entryBlock);
|
||||
builder.create<BranchOp>(loc, clonedLaunchOpEntry);
|
||||
builder.create<cf::BranchOp>(loc, clonedLaunchOpEntry);
|
||||
|
||||
outlinedFunc.walk([](gpu::TerminatorOp op) {
|
||||
OpBuilder replacer(op);
|
||||
|
|
|
@ -7,6 +7,7 @@
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "PassDetail.h"
|
||||
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
|
||||
#include "mlir/Dialect/Linalg/IR/Linalg.h"
|
||||
#include "mlir/Dialect/Linalg/Passes.h"
|
||||
#include "mlir/Dialect/StandardOps/Transforms/FuncConversions.h"
|
||||
|
@ -254,13 +255,13 @@ struct LinalgDetensorize : public LinalgDetensorizeBase<LinalgDetensorize> {
|
|||
DenseSet<BlockArgument> &blockArgsToDetensor) override {
|
||||
SmallVector<Value> workList;
|
||||
|
||||
func->walk([&](CondBranchOp condBr) {
|
||||
func->walk([&](cf::CondBranchOp condBr) {
|
||||
for (auto operand : condBr.getOperands()) {
|
||||
workList.push_back(operand);
|
||||
}
|
||||
});
|
||||
|
||||
func->walk([&](BranchOp br) {
|
||||
func->walk([&](cf::BranchOp br) {
|
||||
for (auto operand : br.getOperands()) {
|
||||
workList.push_back(operand);
|
||||
}
|
||||
|
|
|
@ -9,6 +9,7 @@
|
|||
#include "mlir/Dialect/SCF/SCF.h"
|
||||
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
|
||||
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
|
||||
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||
#include "mlir/IR/BlockAndValueMapping.h"
|
||||
#include "mlir/IR/Matchers.h"
|
||||
|
@ -165,13 +166,13 @@ struct SingleBlockExecuteInliner : public OpRewritePattern<ExecuteRegionOp> {
|
|||
// "test.foo"() : () -> ()
|
||||
// %v = scf.execute_region -> i64 {
|
||||
// %c = "test.cmp"() : () -> i1
|
||||
// cond_br %c, ^bb2, ^bb3
|
||||
// cf.cond_br %c, ^bb2, ^bb3
|
||||
// ^bb2:
|
||||
// %x = "test.val1"() : () -> i64
|
||||
// br ^bb4(%x : i64)
|
||||
// cf.br ^bb4(%x : i64)
|
||||
// ^bb3:
|
||||
// %y = "test.val2"() : () -> i64
|
||||
// br ^bb4(%y : i64)
|
||||
// cf.br ^bb4(%y : i64)
|
||||
// ^bb4(%z : i64):
|
||||
// scf.yield %z : i64
|
||||
// }
|
||||
|
@ -184,13 +185,13 @@ struct SingleBlockExecuteInliner : public OpRewritePattern<ExecuteRegionOp> {
|
|||
// func @func_execute_region_elim() {
|
||||
// "test.foo"() : () -> ()
|
||||
// %c = "test.cmp"() : () -> i1
|
||||
// cond_br %c, ^bb1, ^bb2
|
||||
// cf.cond_br %c, ^bb1, ^bb2
|
||||
// ^bb1: // pred: ^bb0
|
||||
// %x = "test.val1"() : () -> i64
|
||||
// br ^bb3(%x : i64)
|
||||
// cf.br ^bb3(%x : i64)
|
||||
// ^bb2: // pred: ^bb0
|
||||
// %y = "test.val2"() : () -> i64
|
||||
// br ^bb3(%y : i64)
|
||||
// cf.br ^bb3(%y : i64)
|
||||
// ^bb3(%z: i64): // 2 preds: ^bb1, ^bb2
|
||||
// "test.bar"(%z) : (i64) -> ()
|
||||
// return
|
||||
|
@ -208,13 +209,13 @@ struct MultiBlockExecuteInliner : public OpRewritePattern<ExecuteRegionOp> {
|
|||
Block *postBlock = rewriter.splitBlock(prevBlock, op->getIterator());
|
||||
rewriter.setInsertionPointToEnd(prevBlock);
|
||||
|
||||
rewriter.create<BranchOp>(op.getLoc(), &op.getRegion().front());
|
||||
rewriter.create<cf::BranchOp>(op.getLoc(), &op.getRegion().front());
|
||||
|
||||
for (Block &blk : op.getRegion()) {
|
||||
if (YieldOp yieldOp = dyn_cast<YieldOp>(blk.getTerminator())) {
|
||||
rewriter.setInsertionPoint(yieldOp);
|
||||
rewriter.create<BranchOp>(yieldOp.getLoc(), postBlock,
|
||||
yieldOp.getResults());
|
||||
rewriter.create<cf::BranchOp>(yieldOp.getLoc(), postBlock,
|
||||
yieldOp.getResults());
|
||||
rewriter.eraseOp(yieldOp);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -13,7 +13,7 @@ add_mlir_dialect_library(MLIRSparseTensorPipelines
|
|||
MLIRMemRefToLLVM
|
||||
MLIRPass
|
||||
MLIRReconcileUnrealizedCasts
|
||||
MLIRSCFToStandard
|
||||
MLIRSCFToControlFlow
|
||||
MLIRSparseTensor
|
||||
MLIRSparseTensorTransforms
|
||||
MLIRStandardOpsTransforms
|
||||
|
|
|
@ -33,7 +33,7 @@ void mlir::sparse_tensor::buildSparseCompiler(
|
|||
pm.addNestedPass<FuncOp>(createLinalgBufferizePass());
|
||||
pm.addNestedPass<FuncOp>(createConvertLinalgToLoopsPass());
|
||||
pm.addNestedPass<FuncOp>(createConvertVectorToSCFPass());
|
||||
pm.addPass(createLowerToCFGPass()); // --convert-scf-to-std
|
||||
pm.addNestedPass<FuncOp>(createConvertSCFToCFPass());
|
||||
pm.addPass(createFuncBufferizePass());
|
||||
pm.addPass(arith::createConstantBufferizePass());
|
||||
pm.addNestedPass<FuncOp>(createTensorBufferizePass());
|
||||
|
|
|
@ -11,6 +11,7 @@ add_mlir_dialect_library(MLIRStandard
|
|||
MLIRArithmetic
|
||||
MLIRCallInterfaces
|
||||
MLIRCastInterfaces
|
||||
MLIRControlFlow
|
||||
MLIRControlFlowInterfaces
|
||||
MLIRInferTypeOpInterface
|
||||
MLIRIR
|
||||
|
|
|
@ -8,9 +8,8 @@
|
|||
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||
|
||||
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
|
||||
#include "mlir/Dialect/Arithmetic/Utils/Utils.h"
|
||||
#include "mlir/Dialect/CommonFolders.h"
|
||||
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
|
||||
#include "mlir/IR/AffineExpr.h"
|
||||
#include "mlir/IR/AffineMap.h"
|
||||
#include "mlir/IR/BlockAndValueMapping.h"
|
||||
|
@ -77,7 +76,7 @@ struct StdInlinerInterface : public DialectInlinerInterface {
|
|||
|
||||
// Replace the return with a branch to the dest.
|
||||
OpBuilder builder(op);
|
||||
builder.create<BranchOp>(op->getLoc(), newDest, returnOp.getOperands());
|
||||
builder.create<cf::BranchOp>(op->getLoc(), newDest, returnOp.getOperands());
|
||||
op->erase();
|
||||
}
|
||||
|
||||
|
@ -121,130 +120,6 @@ Operation *StandardOpsDialect::materializeConstant(OpBuilder &builder,
|
|||
return nullptr;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// AssertOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// Canonicalizes an assertion: when its argument matches the constant-one
/// pattern (i.e. it is statically true) the assertion is erased. Any other
/// argument leaves the operation untouched and reports failure.
LogicalResult AssertOp::canonicalize(AssertOp op, PatternRewriter &rewriter) {
  // Nothing to do unless the asserted value is a constant true.
  if (!matchPattern(op.getArg(), m_One()))
    return failure();

  rewriter.eraseOp(op);
  return success();
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// BranchOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// Given a successor, try to collapse it to a new destination if it only
|
||||
/// contains a passthrough unconditional branch. If the successor is
|
||||
/// collapsable, `successor` and `successorOperands` are updated to reference
|
||||
/// the new destination and values. `argStorage` is used as storage if operands
|
||||
/// to the collapsed successor need to be remapped. It must outlive uses of
|
||||
/// successorOperands.
|
||||
static LogicalResult collapseBranch(Block *&successor,
|
||||
ValueRange &successorOperands,
|
||||
SmallVectorImpl<Value> &argStorage) {
|
||||
// Check that the successor only contains a unconditional branch.
|
||||
if (std::next(successor->begin()) != successor->end())
|
||||
return failure();
|
||||
// Check that the terminator is an unconditional branch.
|
||||
BranchOp successorBranch = dyn_cast<BranchOp>(successor->getTerminator());
|
||||
if (!successorBranch)
|
||||
return failure();
|
||||
// Check that the arguments are only used within the terminator.
|
||||
for (BlockArgument arg : successor->getArguments()) {
|
||||
for (Operation *user : arg.getUsers())
|
||||
if (user != successorBranch)
|
||||
return failure();
|
||||
}
|
||||
// Don't try to collapse branches to infinite loops.
|
||||
Block *successorDest = successorBranch.getDest();
|
||||
if (successorDest == successor)
|
||||
return failure();
|
||||
|
||||
// Update the operands to the successor. If the branch parent has no
|
||||
// arguments, we can use the branch operands directly.
|
||||
OperandRange operands = successorBranch.getOperands();
|
||||
if (successor->args_empty()) {
|
||||
successor = successorDest;
|
||||
successorOperands = operands;
|
||||
return success();
|
||||
}
|
||||
|
||||
// Otherwise, we need to remap any argument operands.
|
||||
for (Value operand : operands) {
|
||||
BlockArgument argOperand = operand.dyn_cast<BlockArgument>();
|
||||
if (argOperand && argOperand.getOwner() == successor)
|
||||
argStorage.push_back(successorOperands[argOperand.getArgNumber()]);
|
||||
else
|
||||
argStorage.push_back(operand);
|
||||
}
|
||||
successor = successorDest;
|
||||
successorOperands = argStorage;
|
||||
return success();
|
||||
}
|
||||
|
||||
/// Simplify a branch to a block that has a single predecessor. This effectively
|
||||
/// merges the two blocks.
|
||||
static LogicalResult
|
||||
simplifyBrToBlockWithSinglePred(BranchOp op, PatternRewriter &rewriter) {
|
||||
// Check that the successor block has a single predecessor.
|
||||
Block *succ = op.getDest();
|
||||
Block *opParent = op->getBlock();
|
||||
if (succ == opParent || !llvm::hasSingleElement(succ->getPredecessors()))
|
||||
return failure();
|
||||
|
||||
// Merge the successor into the current block and erase the branch.
|
||||
rewriter.mergeBlocks(succ, opParent, op.getOperands());
|
||||
rewriter.eraseOp(op);
|
||||
return success();
|
||||
}
|
||||
|
||||
/// br ^bb1
|
||||
/// ^bb1
|
||||
/// br ^bbN(...)
|
||||
///
|
||||
/// -> br ^bbN(...)
|
||||
///
|
||||
static LogicalResult simplifyPassThroughBr(BranchOp op,
|
||||
PatternRewriter &rewriter) {
|
||||
Block *dest = op.getDest();
|
||||
ValueRange destOperands = op.getOperands();
|
||||
SmallVector<Value, 4> destOperandStorage;
|
||||
|
||||
// Try to collapse the successor if it points somewhere other than this
|
||||
// block.
|
||||
if (dest == op->getBlock() ||
|
||||
failed(collapseBranch(dest, destOperands, destOperandStorage)))
|
||||
return failure();
|
||||
|
||||
// Create a new branch with the collapsed successor.
|
||||
rewriter.replaceOpWithNewOp<BranchOp>(op, dest, destOperands);
|
||||
return success();
|
||||
}
|
||||
|
||||
/// Canonicalizes an unconditional branch. Merging with a single-predecessor
/// destination is tried first; the pass-through simplification is only
/// attempted when that merge did not apply.
LogicalResult BranchOp::canonicalize(BranchOp op, PatternRewriter &rewriter) {
  if (succeeded(simplifyBrToBlockWithSinglePred(op, rewriter)))
    return success();
  return simplifyPassThroughBr(op, rewriter);
}
|
||||
|
||||
/// Makes `block` the successor of this branch.
void BranchOp::setDest(Block *block) {
  setSuccessor(block);
}
|
||||
|
||||
void BranchOp::eraseOperand(unsigned index) { (*this)->eraseOperand(index); }
|
||||
|
||||
/// Returns the mutable range of operands forwarded to the destination block.
/// A branch has a single successor, so `index` is asserted to be 0.
Optional<MutableOperandRange>
BranchOp::getMutableSuccessorOperands(unsigned index) {
  assert(index == 0 && "invalid successor index");
  MutableOperandRange destOperands = getDestOperandsMutable();
  return destOperands;
}
|
||||
|
||||
/// Returns the destination block; the constant operand attributes are
/// ignored because an unconditional branch has only one successor.
Block *BranchOp::getSuccessorForOperands(ArrayRef<Attribute> /*operands*/) {
  Block *onlySuccessor = getDest();
  return onlySuccessor;
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// CallOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -307,260 +182,6 @@ LogicalResult CallIndirectOp::canonicalize(CallIndirectOp indirectCall,
|
|||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// CondBranchOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
namespace {
|
||||
/// cond_br true, ^bb1, ^bb2
|
||||
/// -> br ^bb1
|
||||
/// cond_br false, ^bb1, ^bb2
|
||||
/// -> br ^bb2
|
||||
///
|
||||
struct SimplifyConstCondBranchPred : public OpRewritePattern<CondBranchOp> {
|
||||
using OpRewritePattern<CondBranchOp>::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(CondBranchOp condbr,
|
||||
PatternRewriter &rewriter) const override {
|
||||
if (matchPattern(condbr.getCondition(), m_NonZero())) {
|
||||
// True branch taken.
|
||||
rewriter.replaceOpWithNewOp<BranchOp>(condbr, condbr.getTrueDest(),
|
||||
condbr.getTrueOperands());
|
||||
return success();
|
||||
}
|
||||
if (matchPattern(condbr.getCondition(), m_Zero())) {
|
||||
// False branch taken.
|
||||
rewriter.replaceOpWithNewOp<BranchOp>(condbr, condbr.getFalseDest(),
|
||||
condbr.getFalseOperands());
|
||||
return success();
|
||||
}
|
||||
return failure();
|
||||
}
|
||||
};
|
||||
|
||||
/// cond_br %cond, ^bb1, ^bb2
|
||||
/// ^bb1
|
||||
/// br ^bbN(...)
|
||||
/// ^bb2
|
||||
/// br ^bbK(...)
|
||||
///
|
||||
/// -> cond_br %cond, ^bbN(...), ^bbK(...)
|
||||
///
|
||||
struct SimplifyPassThroughCondBranch : public OpRewritePattern<CondBranchOp> {
|
||||
using OpRewritePattern<CondBranchOp>::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(CondBranchOp condbr,
|
||||
PatternRewriter &rewriter) const override {
|
||||
Block *trueDest = condbr.getTrueDest(), *falseDest = condbr.getFalseDest();
|
||||
ValueRange trueDestOperands = condbr.getTrueOperands();
|
||||
ValueRange falseDestOperands = condbr.getFalseOperands();
|
||||
SmallVector<Value, 4> trueDestOperandStorage, falseDestOperandStorage;
|
||||
|
||||
// Try to collapse one of the current successors.
|
||||
LogicalResult collapsedTrue =
|
||||
collapseBranch(trueDest, trueDestOperands, trueDestOperandStorage);
|
||||
LogicalResult collapsedFalse =
|
||||
collapseBranch(falseDest, falseDestOperands, falseDestOperandStorage);
|
||||
if (failed(collapsedTrue) && failed(collapsedFalse))
|
||||
return failure();
|
||||
|
||||
// Create a new branch with the collapsed successors.
|
||||
rewriter.replaceOpWithNewOp<CondBranchOp>(condbr, condbr.getCondition(),
|
||||
trueDest, trueDestOperands,
|
||||
falseDest, falseDestOperands);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
/// cond_br %cond, ^bb1(A, ..., N), ^bb1(A, ..., N)
|
||||
/// -> br ^bb1(A, ..., N)
|
||||
///
|
||||
/// cond_br %cond, ^bb1(A), ^bb1(B)
|
||||
/// -> %select = arith.select %cond, A, B
|
||||
/// br ^bb1(%select)
|
||||
///
|
||||
struct SimplifyCondBranchIdenticalSuccessors
|
||||
: public OpRewritePattern<CondBranchOp> {
|
||||
using OpRewritePattern<CondBranchOp>::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(CondBranchOp condbr,
|
||||
PatternRewriter &rewriter) const override {
|
||||
// Check that the true and false destinations are the same and have the same
|
||||
// operands.
|
||||
Block *trueDest = condbr.getTrueDest();
|
||||
if (trueDest != condbr.getFalseDest())
|
||||
return failure();
|
||||
|
||||
// If all of the operands match, no selects need to be generated.
|
||||
OperandRange trueOperands = condbr.getTrueOperands();
|
||||
OperandRange falseOperands = condbr.getFalseOperands();
|
||||
if (trueOperands == falseOperands) {
|
||||
rewriter.replaceOpWithNewOp<BranchOp>(condbr, trueDest, trueOperands);
|
||||
return success();
|
||||
}
|
||||
|
||||
// Otherwise, if the current block is the only predecessor insert selects
|
||||
// for any mismatched branch operands.
|
||||
if (trueDest->getUniquePredecessor() != condbr->getBlock())
|
||||
return failure();
|
||||
|
||||
// Generate a select for any operands that differ between the two.
|
||||
SmallVector<Value, 8> mergedOperands;
|
||||
mergedOperands.reserve(trueOperands.size());
|
||||
Value condition = condbr.getCondition();
|
||||
for (auto it : llvm::zip(trueOperands, falseOperands)) {
|
||||
if (std::get<0>(it) == std::get<1>(it))
|
||||
mergedOperands.push_back(std::get<0>(it));
|
||||
else
|
||||
mergedOperands.push_back(rewriter.create<arith::SelectOp>(
|
||||
condbr.getLoc(), condition, std::get<0>(it), std::get<1>(it)));
|
||||
}
|
||||
|
||||
rewriter.replaceOpWithNewOp<BranchOp>(condbr, trueDest, mergedOperands);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
/// ...
|
||||
/// cond_br %cond, ^bb1(...), ^bb2(...)
|
||||
/// ...
|
||||
/// ^bb1: // has single predecessor
|
||||
/// ...
|
||||
/// cond_br %cond, ^bb3(...), ^bb4(...)
|
||||
///
|
||||
/// ->
|
||||
///
|
||||
/// ...
|
||||
/// cond_br %cond, ^bb1(...), ^bb2(...)
|
||||
/// ...
|
||||
/// ^bb1: // has single predecessor
|
||||
/// ...
|
||||
/// br ^bb3(...)
|
||||
///
|
||||
struct SimplifyCondBranchFromCondBranchOnSameCondition
|
||||
: public OpRewritePattern<CondBranchOp> {
|
||||
using OpRewritePattern<CondBranchOp>::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(CondBranchOp condbr,
|
||||
PatternRewriter &rewriter) const override {
|
||||
// Check that we have a single distinct predecessor.
|
||||
Block *currentBlock = condbr->getBlock();
|
||||
Block *predecessor = currentBlock->getSinglePredecessor();
|
||||
if (!predecessor)
|
||||
return failure();
|
||||
|
||||
// Check that the predecessor terminates with a conditional branch to this
|
||||
// block and that it branches on the same condition.
|
||||
auto predBranch = dyn_cast<CondBranchOp>(predecessor->getTerminator());
|
||||
if (!predBranch || condbr.getCondition() != predBranch.getCondition())
|
||||
return failure();
|
||||
|
||||
// Fold this branch to an unconditional branch.
|
||||
if (currentBlock == predBranch.getTrueDest())
|
||||
rewriter.replaceOpWithNewOp<BranchOp>(condbr, condbr.getTrueDest(),
|
||||
condbr.getTrueDestOperands());
|
||||
else
|
||||
rewriter.replaceOpWithNewOp<BranchOp>(condbr, condbr.getFalseDest(),
|
||||
condbr.getFalseDestOperands());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
/// cond_br %arg0, ^trueB, ^falseB
|
||||
///
|
||||
/// ^trueB:
|
||||
/// "test.consumer1"(%arg0) : (i1) -> ()
|
||||
/// ...
|
||||
///
|
||||
/// ^falseB:
|
||||
/// "test.consumer2"(%arg0) : (i1) -> ()
|
||||
/// ...
|
||||
///
|
||||
/// ->
|
||||
///
|
||||
/// cond_br %arg0, ^trueB, ^falseB
|
||||
/// ^trueB:
|
||||
/// "test.consumer1"(%true) : (i1) -> ()
|
||||
/// ...
|
||||
///
|
||||
/// ^falseB:
|
||||
/// "test.consumer2"(%false) : (i1) -> ()
|
||||
/// ...
|
||||
struct CondBranchTruthPropagation : public OpRewritePattern<CondBranchOp> {
|
||||
using OpRewritePattern<CondBranchOp>::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(CondBranchOp condbr,
|
||||
PatternRewriter &rewriter) const override {
|
||||
// Check that we have a single distinct predecessor.
|
||||
bool replaced = false;
|
||||
Type ty = rewriter.getI1Type();
|
||||
|
||||
// These variables serve to prevent creating duplicate constants
|
||||
// and hold constant true or false values.
|
||||
Value constantTrue = nullptr;
|
||||
Value constantFalse = nullptr;
|
||||
|
||||
// TODO These checks can be expanded to encompas any use with only
|
||||
// either the true of false edge as a predecessor. For now, we fall
|
||||
// back to checking the single predecessor is given by the true/fasle
|
||||
// destination, thereby ensuring that only that edge can reach the
|
||||
// op.
|
||||
if (condbr.getTrueDest()->getSinglePredecessor()) {
|
||||
for (OpOperand &use :
|
||||
llvm::make_early_inc_range(condbr.getCondition().getUses())) {
|
||||
if (use.getOwner()->getBlock() == condbr.getTrueDest()) {
|
||||
replaced = true;
|
||||
|
||||
if (!constantTrue)
|
||||
constantTrue = rewriter.create<arith::ConstantOp>(
|
||||
condbr.getLoc(), ty, rewriter.getBoolAttr(true));
|
||||
|
||||
rewriter.updateRootInPlace(use.getOwner(),
|
||||
[&] { use.set(constantTrue); });
|
||||
}
|
||||
}
|
||||
}
|
||||
if (condbr.getFalseDest()->getSinglePredecessor()) {
|
||||
for (OpOperand &use :
|
||||
llvm::make_early_inc_range(condbr.getCondition().getUses())) {
|
||||
if (use.getOwner()->getBlock() == condbr.getFalseDest()) {
|
||||
replaced = true;
|
||||
|
||||
if (!constantFalse)
|
||||
constantFalse = rewriter.create<arith::ConstantOp>(
|
||||
condbr.getLoc(), ty, rewriter.getBoolAttr(false));
|
||||
|
||||
rewriter.updateRootInPlace(use.getOwner(),
|
||||
[&] { use.set(constantFalse); });
|
||||
}
|
||||
}
|
||||
}
|
||||
return success(replaced);
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
/// Registers the canonicalization patterns of the conditional branch with
/// `results`, one pattern at a time and in the order listed below.
void CondBranchOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                               MLIRContext *context) {
  results.add<SimplifyConstCondBranchPred>(context);
  results.add<SimplifyPassThroughCondBranch>(context);
  results.add<SimplifyCondBranchIdenticalSuccessors>(context);
  results.add<SimplifyCondBranchFromCondBranchOnSameCondition>(context);
  results.add<CondBranchTruthPropagation>(context);
}
|
||||
|
||||
/// Returns the mutable operands forwarded to the successor at `index`: the
/// true-destination operands for `trueIndex`, the false-destination operands
/// for any other valid index.
Optional<MutableOperandRange>
CondBranchOp::getMutableSuccessorOperands(unsigned index) {
  assert(index < getNumSuccessors() && "invalid successor index");
  if (index == trueIndex)
    return getTrueDestOperandsMutable();
  return getFalseDestOperandsMutable();
}
|
||||
|
||||
/// Picks the successor for constant operands. When the first operand
/// attribute is an integer attribute, the true destination is returned if
/// its value is one and the false destination otherwise. Without an integer
/// attribute the successor is unknown and nullptr is returned.
Block *CondBranchOp::getSuccessorForOperands(ArrayRef<Attribute> operands) {
  IntegerAttr condAttr = operands.front().dyn_cast_or_null<IntegerAttr>();
  if (!condAttr)
    return nullptr;

  if (condAttr.getValue().isOneValue())
    return getTrueDest();
  return getFalseDest();
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ConstantOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -621,439 +242,6 @@ LogicalResult ReturnOp::verify() {
|
|||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// SwitchOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// Builds a switch from an already assembled case-value attribute. This
/// overload only reorders its arguments and forwards them to another `build`
/// overload, which takes the operands first, then the case values, then the
/// destinations.
void SwitchOp::build(OpBuilder &builder, OperationState &result, Value value,
                     Block *defaultDestination, ValueRange defaultOperands,
                     DenseIntElementsAttr caseValues,
                     BlockRange caseDestinations,
                     ArrayRef<ValueRange> caseOperands) {
  build(builder, result,
        /*flag=*/value,
        /*operands=*/defaultOperands, caseOperands,
        /*case values=*/caseValues,
        /*destinations=*/defaultDestination, caseDestinations);
}
|
||||
|
||||
/// Builds a switch from a plain list of case values. A non-empty list is
/// packed into a dense integer attribute shaped as a one-dimensional vector
/// of the flag's type; an empty list leaves the attribute null. The work is
/// then delegated to the attribute-based overload.
void SwitchOp::build(OpBuilder &builder, OperationState &result, Value value,
                     Block *defaultDestination, ValueRange defaultOperands,
                     ArrayRef<APInt> caseValues, BlockRange caseDestinations,
                     ArrayRef<ValueRange> caseOperands) {
  DenseIntElementsAttr packedCaseValues;
  const bool hasCases = !caseValues.empty();
  if (hasCases) {
    const int64_t numCases = static_cast<int64_t>(caseValues.size());
    ShapedType packedType = VectorType::get(numCases, value.getType());
    packedCaseValues = DenseIntElementsAttr::get(packedType, caseValues);
  }

  build(builder, result, value, defaultDestination, defaultOperands,
        packedCaseValues, caseDestinations, caseOperands);
}
|
||||
|
||||
/// <cases> ::= `default` `:` bb-id (`(` ssa-use-and-type-list `)`)?
|
||||
/// ( `,` integer `:` bb-id (`(` ssa-use-and-type-list `)`)? )*
|
||||
static ParseResult parseSwitchOpCases(
|
||||
OpAsmParser &parser, Type &flagType, Block *&defaultDestination,
|
||||
SmallVectorImpl<OpAsmParser::OperandType> &defaultOperands,
|
||||
SmallVectorImpl<Type> &defaultOperandTypes,
|
||||
DenseIntElementsAttr &caseValues,
|
||||
SmallVectorImpl<Block *> &caseDestinations,
|
||||
SmallVectorImpl<SmallVector<OpAsmParser::OperandType>> &caseOperands,
|
||||
SmallVectorImpl<SmallVector<Type>> &caseOperandTypes) {
|
||||
if (parser.parseKeyword("default") || parser.parseColon() ||
|
||||
parser.parseSuccessor(defaultDestination))
|
||||
return failure();
|
||||
if (succeeded(parser.parseOptionalLParen())) {
|
||||
if (parser.parseRegionArgumentList(defaultOperands) ||
|
||||
parser.parseColonTypeList(defaultOperandTypes) || parser.parseRParen())
|
||||
return failure();
|
||||
}
|
||||
|
||||
SmallVector<APInt> values;
|
||||
unsigned bitWidth = flagType.getIntOrFloatBitWidth();
|
||||
while (succeeded(parser.parseOptionalComma())) {
|
||||
int64_t value = 0;
|
||||
if (failed(parser.parseInteger(value)))
|
||||
return failure();
|
||||
values.push_back(APInt(bitWidth, value));
|
||||
|
||||
Block *destination;
|
||||
SmallVector<OpAsmParser::OperandType> operands;
|
||||
SmallVector<Type> operandTypes;
|
||||
if (failed(parser.parseColon()) ||
|
||||
failed(parser.parseSuccessor(destination)))
|
||||
return failure();
|
||||
if (succeeded(parser.parseOptionalLParen())) {
|
||||
if (failed(parser.parseRegionArgumentList(operands)) ||
|
||||
failed(parser.parseColonTypeList(operandTypes)) ||
|
||||
failed(parser.parseRParen()))
|
||||
return failure();
|
||||
}
|
||||
caseDestinations.push_back(destination);
|
||||
caseOperands.emplace_back(operands);
|
||||
caseOperandTypes.emplace_back(operandTypes);
|
||||
}
|
||||
|
||||
if (!values.empty()) {
|
||||
ShapedType caseValueType =
|
||||
VectorType::get(static_cast<int64_t>(values.size()), flagType);
|
||||
caseValues = DenseIntElementsAttr::get(caseValueType, values);
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
||||
/// Prints the case list of a switch operation: the default destination
/// first, then one `value: destination` entry per case, each on a new line.
/// Nothing beyond the default entry is printed when `caseValues` is null.
static void printSwitchOpCases(
    OpAsmPrinter &p, SwitchOp op, Type flagType, Block *defaultDestination,
    OperandRange defaultOperands, TypeRange defaultOperandTypes,
    DenseIntElementsAttr caseValues, SuccessorRange caseDestinations,
    OperandRangeRange caseOperands, const TypeRangeRange &caseOperandTypes) {
  p << " default: ";
  p.printSuccessorAndUseList(defaultDestination, defaultOperands);

  if (caseValues) {
    for (const auto &indexedCase :
         llvm::enumerate(caseValues.getValues<APInt>())) {
      const auto caseIdx = indexedCase.index();
      // Separator after the previous entry, then the entry on its own line.
      p << ',';
      p.printNewline();
      p << " " << indexedCase.value().getLimitedValue() << ": ";
      p.printSuccessorAndUseList(caseDestinations[caseIdx],
                                 caseOperands[caseIdx]);
    }
    p.printNewline();
  }
}
|
||||
|
||||
/// Verifies the consistency of the case list:
///  - a switch with neither case values nor case destinations is valid;
///  - case destinations without a case-value attribute are rejected with the
///    count-mismatch diagnostic (previously this path dereferenced the absent
///    optional attribute);
///  - the element type of the case values must equal the type of the flag;
///  - the number of case values must equal the number of case destinations.
LogicalResult SwitchOp::verify() {
  auto caseValues = getCaseValues();
  auto caseDestinations = getCaseDestinations();

  if (!caseValues) {
    if (caseDestinations.empty())
      return success();
    // There are destinations but no values to select them: report it as a
    // count mismatch instead of touching the missing attribute.
    return emitOpError() << "number of case values (" << 0
                         << ") should match number of "
                            "case destinations ("
                         << caseDestinations.size() << ")";
  }

  Type flagType = getFlag().getType();
  Type caseValueType = caseValues->getType().getElementType();
  if (caseValueType != flagType)
    return emitOpError() << "'flag' type (" << flagType
                         << ") should match case value type (" << caseValueType
                         << ")";

  // `caseValues` is known to be present from here on.
  if (caseValues->size() != static_cast<int64_t>(caseDestinations.size()))
    return emitOpError() << "number of case values (" << caseValues->size()
                         << ") should match number of "
                            "case destinations ("
                         << caseDestinations.size() << ")";
  return success();
}
|
||||
|
||||
/// Returns the mutable operands forwarded to the successor at `index`.
/// Successor 0 is the default destination; successor `i > 0` corresponds to
/// case `i - 1`.
Optional<MutableOperandRange>
SwitchOp::getMutableSuccessorOperands(unsigned index) {
  assert(index < getNumSuccessors() && "invalid successor index");
  if (index == 0)
    return getDefaultOperandsMutable();
  return getCaseOperandsMutable(index - 1);
}
|
||||
|
||||
/// Picks the successor for constant operands. A switch without case values
/// always goes to the default destination. Otherwise, when the first operand
/// attribute is an integer attribute, the destination of the first case with
/// an equal value is returned, falling back to the default destination when
/// no case matches. Without an integer attribute the successor is unknown and
/// nullptr is returned.
Block *SwitchOp::getSuccessorForOperands(ArrayRef<Attribute> operands) {
  Optional<DenseIntElementsAttr> caseValues = getCaseValues();
  if (!caseValues)
    return getDefaultDestination();

  SuccessorRange caseDests = getCaseDestinations();
  auto flagAttr = operands.front().dyn_cast_or_null<IntegerAttr>();
  if (!flagAttr)
    return nullptr;

  for (const auto &indexedCase :
       llvm::enumerate(caseValues->getValues<APInt>())) {
    if (indexedCase.value() == flagAttr.getValue())
      return caseDests[indexedCase.index()];
  }
  return getDefaultDestination();
}
|
||||
|
||||
/// switch %flag : i32, [
|
||||
/// default: ^bb1
|
||||
/// ]
|
||||
/// -> br ^bb1
|
||||
static LogicalResult simplifySwitchWithOnlyDefault(SwitchOp op,
|
||||
PatternRewriter &rewriter) {
|
||||
if (!op.getCaseDestinations().empty())
|
||||
return failure();
|
||||
|
||||
rewriter.replaceOpWithNewOp<BranchOp>(op, op.getDefaultDestination(),
|
||||
op.getDefaultOperands());
|
||||
return success();
|
||||
}
|
||||
|
||||
/// switch %flag : i32, [
|
||||
/// default: ^bb1,
|
||||
/// 42: ^bb1,
|
||||
/// 43: ^bb2
|
||||
/// ]
|
||||
/// ->
|
||||
/// switch %flag : i32, [
|
||||
/// default: ^bb1,
|
||||
/// 43: ^bb2
|
||||
/// ]
|
||||
static LogicalResult
|
||||
dropSwitchCasesThatMatchDefault(SwitchOp op, PatternRewriter &rewriter) {
|
||||
SmallVector<Block *> newCaseDestinations;
|
||||
SmallVector<ValueRange> newCaseOperands;
|
||||
SmallVector<APInt> newCaseValues;
|
||||
bool requiresChange = false;
|
||||
auto caseValues = op.getCaseValues();
|
||||
auto caseDests = op.getCaseDestinations();
|
||||
|
||||
for (const auto &it : llvm::enumerate(caseValues->getValues<APInt>())) {
|
||||
if (caseDests[it.index()] == op.getDefaultDestination() &&
|
||||
op.getCaseOperands(it.index()) == op.getDefaultOperands()) {
|
||||
requiresChange = true;
|
||||
continue;
|
||||
}
|
||||
newCaseDestinations.push_back(caseDests[it.index()]);
|
||||
newCaseOperands.push_back(op.getCaseOperands(it.index()));
|
||||
newCaseValues.push_back(it.value());
|
||||
}
|
||||
|
||||
if (!requiresChange)
|
||||
return failure();
|
||||
|
||||
rewriter.replaceOpWithNewOp<SwitchOp>(
|
||||
op, op.getFlag(), op.getDefaultDestination(), op.getDefaultOperands(),
|
||||
newCaseValues, newCaseDestinations, newCaseOperands);
|
||||
return success();
|
||||
}
|
||||
|
||||
/// Helper for folding a switch with a constant value.
|
||||
/// switch %c_42 : i32, [
|
||||
/// default: ^bb1 ,
|
||||
/// 42: ^bb2,
|
||||
/// 43: ^bb3
|
||||
/// ]
|
||||
/// -> br ^bb2
|
||||
static void foldSwitch(SwitchOp op, PatternRewriter &rewriter,
|
||||
const APInt &caseValue) {
|
||||
auto caseValues = op.getCaseValues();
|
||||
for (const auto &it : llvm::enumerate(caseValues->getValues<APInt>())) {
|
||||
if (it.value() == caseValue) {
|
||||
rewriter.replaceOpWithNewOp<BranchOp>(
|
||||
op, op.getCaseDestinations()[it.index()],
|
||||
op.getCaseOperands(it.index()));
|
||||
return;
|
||||
}
|
||||
}
|
||||
rewriter.replaceOpWithNewOp<BranchOp>(op, op.getDefaultDestination(),
|
||||
op.getDefaultOperands());
|
||||
}
|
||||
|
||||
/// switch %c_42 : i32, [
|
||||
/// default: ^bb1,
|
||||
/// 42: ^bb2,
|
||||
/// 43: ^bb3
|
||||
/// ]
|
||||
/// -> br ^bb2
|
||||
static LogicalResult simplifyConstSwitchValue(SwitchOp op,
|
||||
PatternRewriter &rewriter) {
|
||||
APInt caseValue;
|
||||
if (!matchPattern(op.getFlag(), m_ConstantInt(&caseValue)))
|
||||
return failure();
|
||||
|
||||
foldSwitch(op, rewriter, caseValue);
|
||||
return success();
|
||||
}
|
||||
|
||||
/// switch %c_42 : i32, [
|
||||
/// default: ^bb1,
|
||||
/// 42: ^bb2,
|
||||
/// ]
|
||||
/// ^bb2:
|
||||
/// br ^bb3
|
||||
/// ->
|
||||
/// switch %c_42 : i32, [
|
||||
/// default: ^bb1,
|
||||
/// 42: ^bb3,
|
||||
/// ]
|
||||
static LogicalResult simplifyPassThroughSwitch(SwitchOp op,
|
||||
PatternRewriter &rewriter) {
|
||||
SmallVector<Block *> newCaseDests;
|
||||
SmallVector<ValueRange> newCaseOperands;
|
||||
SmallVector<SmallVector<Value>> argStorage;
|
||||
auto caseValues = op.getCaseValues();
|
||||
auto caseDests = op.getCaseDestinations();
|
||||
bool requiresChange = false;
|
||||
for (int64_t i = 0, size = caseValues->size(); i < size; ++i) {
|
||||
Block *caseDest = caseDests[i];
|
||||
ValueRange caseOperands = op.getCaseOperands(i);
|
||||
argStorage.emplace_back();
|
||||
if (succeeded(collapseBranch(caseDest, caseOperands, argStorage.back())))
|
||||
requiresChange = true;
|
||||
|
||||
newCaseDests.push_back(caseDest);
|
||||
newCaseOperands.push_back(caseOperands);
|
||||
}
|
||||
|
||||
Block *defaultDest = op.getDefaultDestination();
|
||||
ValueRange defaultOperands = op.getDefaultOperands();
|
||||
argStorage.emplace_back();
|
||||
|
||||
if (succeeded(
|
||||
collapseBranch(defaultDest, defaultOperands, argStorage.back())))
|
||||
requiresChange = true;
|
||||
|
||||
if (!requiresChange)
|
||||
return failure();
|
||||
|
||||
rewriter.replaceOpWithNewOp<SwitchOp>(op, op.getFlag(), defaultDest,
|
||||
defaultOperands, caseValues.getValue(),
|
||||
newCaseDests, newCaseOperands);
|
||||
return success();
|
||||
}
|
||||
|
||||
/// switch %flag : i32, [
|
||||
/// default: ^bb1,
|
||||
/// 42: ^bb2,
|
||||
/// ]
|
||||
/// ^bb2:
|
||||
/// switch %flag : i32, [
|
||||
/// default: ^bb3,
|
||||
/// 42: ^bb4
|
||||
/// ]
|
||||
/// ->
|
||||
/// switch %flag : i32, [
|
||||
/// default: ^bb1,
|
||||
/// 42: ^bb2,
|
||||
/// ]
|
||||
/// ^bb2:
|
||||
/// br ^bb4
|
||||
///
|
||||
/// and
|
||||
///
|
||||
/// switch %flag : i32, [
|
||||
/// default: ^bb1,
|
||||
/// 42: ^bb2,
|
||||
/// ]
|
||||
/// ^bb2:
|
||||
/// switch %flag : i32, [
|
||||
/// default: ^bb3,
|
||||
/// 43: ^bb4
|
||||
/// ]
|
||||
/// ->
|
||||
/// switch %flag : i32, [
|
||||
/// default: ^bb1,
|
||||
/// 42: ^bb2,
|
||||
/// ]
|
||||
/// ^bb2:
|
||||
/// br ^bb3
|
||||
static LogicalResult
|
||||
simplifySwitchFromSwitchOnSameCondition(SwitchOp op,
|
||||
PatternRewriter &rewriter) {
|
||||
// Check that we have a single distinct predecessor.
|
||||
Block *currentBlock = op->getBlock();
|
||||
Block *predecessor = currentBlock->getSinglePredecessor();
|
||||
if (!predecessor)
|
||||
return failure();
|
||||
|
||||
// Check that the predecessor terminates with a switch branch to this block
|
||||
// and that it branches on the same condition and that this branch isn't the
|
||||
// default destination.
|
||||
auto predSwitch = dyn_cast<SwitchOp>(predecessor->getTerminator());
|
||||
if (!predSwitch || op.getFlag() != predSwitch.getFlag() ||
|
||||
predSwitch.getDefaultDestination() == currentBlock)
|
||||
return failure();
|
||||
|
||||
// Fold this switch to an unconditional branch.
|
||||
SuccessorRange predDests = predSwitch.getCaseDestinations();
|
||||
auto it = llvm::find(predDests, currentBlock);
|
||||
if (it != predDests.end()) {
|
||||
Optional<DenseIntElementsAttr> predCaseValues = predSwitch.getCaseValues();
|
||||
foldSwitch(op, rewriter,
|
||||
predCaseValues->getValues<APInt>()[it - predDests.begin()]);
|
||||
} else {
|
||||
rewriter.replaceOpWithNewOp<BranchOp>(op, op.getDefaultDestination(),
|
||||
op.getDefaultOperands());
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
||||
/// switch %flag : i32, [
|
||||
/// default: ^bb1,
|
||||
/// 42: ^bb2
|
||||
/// ]
|
||||
/// ^bb1:
|
||||
/// switch %flag : i32, [
|
||||
/// default: ^bb3,
|
||||
/// 42: ^bb4,
|
||||
/// 43: ^bb5
|
||||
/// ]
|
||||
/// ->
|
||||
/// switch %flag : i32, [
|
||||
/// default: ^bb1,
|
||||
/// 42: ^bb2,
|
||||
/// ]
|
||||
/// ^bb1:
|
||||
/// switch %flag : i32, [
|
||||
/// default: ^bb3,
|
||||
/// 43: ^bb5
|
||||
/// ]
|
||||
static LogicalResult
|
||||
simplifySwitchFromDefaultSwitchOnSameCondition(SwitchOp op,
|
||||
PatternRewriter &rewriter) {
|
||||
// Check that we have a single distinct predecessor.
|
||||
Block *currentBlock = op->getBlock();
|
||||
Block *predecessor = currentBlock->getSinglePredecessor();
|
||||
if (!predecessor)
|
||||
return failure();
|
||||
|
||||
// Check that the predecessor terminates with a switch branch to this block
|
||||
// and that it branches on the same condition and that this branch is the
|
||||
// default destination.
|
||||
auto predSwitch = dyn_cast<SwitchOp>(predecessor->getTerminator());
|
||||
if (!predSwitch || op.getFlag() != predSwitch.getFlag() ||
|
||||
predSwitch.getDefaultDestination() != currentBlock)
|
||||
return failure();
|
||||
|
||||
// Delete case values that are not possible here.
|
||||
DenseSet<APInt> caseValuesToRemove;
|
||||
auto predDests = predSwitch.getCaseDestinations();
|
||||
auto predCaseValues = predSwitch.getCaseValues();
|
||||
for (int64_t i = 0, size = predCaseValues->size(); i < size; ++i)
|
||||
if (currentBlock != predDests[i])
|
||||
caseValuesToRemove.insert(predCaseValues->getValues<APInt>()[i]);
|
||||
|
||||
SmallVector<Block *> newCaseDestinations;
|
||||
SmallVector<ValueRange> newCaseOperands;
|
||||
SmallVector<APInt> newCaseValues;
|
||||
bool requiresChange = false;
|
||||
|
||||
auto caseValues = op.getCaseValues();
|
||||
auto caseDests = op.getCaseDestinations();
|
||||
for (const auto &it : llvm::enumerate(caseValues->getValues<APInt>())) {
|
||||
if (caseValuesToRemove.contains(it.value())) {
|
||||
requiresChange = true;
|
||||
continue;
|
||||
}
|
||||
newCaseDestinations.push_back(caseDests[it.index()]);
|
||||
newCaseOperands.push_back(op.getCaseOperands(it.index()));
|
||||
newCaseValues.push_back(it.value());
|
||||
}
|
||||
|
||||
if (!requiresChange)
|
||||
return failure();
|
||||
|
||||
rewriter.replaceOpWithNewOp<SwitchOp>(
|
||||
op, op.getFlag(), op.getDefaultDestination(), op.getDefaultOperands(),
|
||||
newCaseValues, newCaseDestinations, newCaseOperands);
|
||||
return success();
|
||||
}
|
||||
|
||||
/// Registers the switch canonicalization patterns defined in this file.
/// `context` is not used by the function-based patterns added here.
void SwitchOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                           MLIRContext *context) {
  // Registration order is kept exactly as in the chained form.
  // NOTE(review): the only-default pattern is registered first; confirm
  // whether later patterns rely on it being tried before them.
  results.add(&simplifySwitchWithOnlyDefault);
  results.add(&dropSwitchCasesThatMatchDefault);
  results.add(&simplifyConstSwitchValue);
  results.add(&simplifyPassThroughSwitch);
  results.add(&simplifySwitchFromSwitchOnSameCondition);
  results.add(&simplifySwitchFromDefaultSwitchOnSameCondition);
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// TableGen'd op method definitions
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -7,6 +7,7 @@
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
|
||||
#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
|
||||
#include "mlir/Dialect/EmitC/IR/EmitC.h"
|
||||
#include "mlir/Dialect/Math/IR/Math.h"
|
||||
#include "mlir/Dialect/SCF/SCF.h"
|
||||
|
@ -41,6 +42,7 @@ void registerToCppTranslation() {
|
|||
[](DialectRegistry ®istry) {
|
||||
// clang-format off
|
||||
registry.insert<arith::ArithmeticDialect,
|
||||
cf::ControlFlowDialect,
|
||||
emitc::EmitCDialect,
|
||||
math::MathDialect,
|
||||
StandardOpsDialect,
|
||||
|
|
|
@ -6,8 +6,7 @@
|
|||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include <utility>
|
||||
|
||||
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
|
||||
#include "mlir/Dialect/EmitC/IR/EmitC.h"
|
||||
#include "mlir/Dialect/SCF/SCF.h"
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||
|
@ -23,6 +22,7 @@
|
|||
#include "llvm/ADT/TypeSwitch.h"
|
||||
#include "llvm/Support/Debug.h"
|
||||
#include "llvm/Support/FormatVariadic.h"
|
||||
#include <utility>
|
||||
|
||||
#define DEBUG_TYPE "translate-to-cpp"
|
||||
|
||||
|
@ -237,7 +237,8 @@ static LogicalResult printOperation(CppEmitter &emitter,
|
|||
return printConstantOp(emitter, operation, value);
|
||||
}
|
||||
|
||||
static LogicalResult printOperation(CppEmitter &emitter, BranchOp branchOp) {
|
||||
static LogicalResult printOperation(CppEmitter &emitter,
|
||||
cf::BranchOp branchOp) {
|
||||
raw_ostream &os = emitter.ostream();
|
||||
Block &successor = *branchOp.getSuccessor();
|
||||
|
||||
|
@ -257,7 +258,7 @@ static LogicalResult printOperation(CppEmitter &emitter, BranchOp branchOp) {
|
|||
}
|
||||
|
||||
static LogicalResult printOperation(CppEmitter &emitter,
|
||||
CondBranchOp condBranchOp) {
|
||||
cf::CondBranchOp condBranchOp) {
|
||||
raw_indented_ostream &os = emitter.ostream();
|
||||
Block &trueSuccessor = *condBranchOp.getTrueDest();
|
||||
Block &falseSuccessor = *condBranchOp.getFalseDest();
|
||||
|
@ -637,11 +638,12 @@ static LogicalResult printOperation(CppEmitter &emitter, FuncOp functionOp) {
|
|||
return failure();
|
||||
}
|
||||
for (Operation &op : block.getOperations()) {
|
||||
// When generating code for an scf.if or std.cond_br op no semicolon needs
|
||||
// When generating code for an scf.if or cf.cond_br op no semicolon needs
|
||||
// to be printed after the closing brace.
|
||||
// When generating code for an scf.for op, printing a trailing semicolon
|
||||
// is handled within the printOperation function.
|
||||
bool trailingSemicolon = !isa<scf::IfOp, scf::ForOp, CondBranchOp>(op);
|
||||
bool trailingSemicolon =
|
||||
!isa<scf::IfOp, scf::ForOp, cf::CondBranchOp>(op);
|
||||
|
||||
if (failed(emitter.emitOperation(
|
||||
op, /*trailingSemicolon=*/trailingSemicolon)))
|
||||
|
@ -907,8 +909,8 @@ LogicalResult CppEmitter::emitOperation(Operation &op, bool trailingSemicolon) {
|
|||
.Case<scf::ForOp, scf::IfOp, scf::YieldOp>(
|
||||
[&](auto op) { return printOperation(*this, op); })
|
||||
// Standard ops.
|
||||
.Case<BranchOp, mlir::CallOp, CondBranchOp, mlir::ConstantOp, FuncOp,
|
||||
ModuleOp, ReturnOp>(
|
||||
.Case<cf::BranchOp, mlir::CallOp, cf::CondBranchOp, mlir::ConstantOp,
|
||||
FuncOp, ModuleOp, ReturnOp>(
|
||||
[&](auto op) { return printOperation(*this, op); })
|
||||
// Arithmetic ops.
|
||||
.Case<arith::ConstantOp>(
|
||||
|
|
|
@ -52,10 +52,10 @@ func @control_flow(%arg: memref<2xf32>, %cond: i1) attributes {test.ptr = "func"
|
|||
%1 = memref.alloca() {test.ptr = "alloca_2"} : memref<8x64xf32>
|
||||
%2 = memref.alloc() {test.ptr = "alloc_1"} : memref<8x64xf32>
|
||||
|
||||
cond_br %cond, ^bb1(%0 : memref<8x64xf32>), ^bb2(%0 : memref<8x64xf32>)
|
||||
cf.cond_br %cond, ^bb1(%0 : memref<8x64xf32>), ^bb2(%0 : memref<8x64xf32>)
|
||||
|
||||
^bb1(%arg1: memref<8x64xf32>):
|
||||
br ^bb2(%arg1 : memref<8x64xf32>)
|
||||
cf.br ^bb2(%arg1 : memref<8x64xf32>)
|
||||
|
||||
^bb2(%arg2: memref<8x64xf32>):
|
||||
return
|
||||
|
@ -85,10 +85,10 @@ func @control_flow_merge(%arg: memref<2xf32>, %cond: i1) attributes {test.ptr =
|
|||
%1 = memref.alloca() {test.ptr = "alloca_2"} : memref<8x64xf32>
|
||||
%2 = memref.alloc() {test.ptr = "alloc_1"} : memref<8x64xf32>
|
||||
|
||||
cond_br %cond, ^bb1(%0 : memref<8x64xf32>), ^bb2(%2 : memref<8x64xf32>)
|
||||
cf.cond_br %cond, ^bb1(%0 : memref<8x64xf32>), ^bb2(%2 : memref<8x64xf32>)
|
||||
|
||||
^bb1(%arg1: memref<8x64xf32>):
|
||||
br ^bb2(%arg1 : memref<8x64xf32>)
|
||||
cf.br ^bb2(%arg1 : memref<8x64xf32>)
|
||||
|
||||
^bb2(%arg2: memref<8x64xf32>):
|
||||
return
|
||||
|
|
|
@ -2,11 +2,11 @@
|
|||
|
||||
// CHECK-LABEL: Testing : func_condBranch
|
||||
func @func_condBranch(%cond : i1) {
|
||||
cond_br %cond, ^bb1, ^bb2
|
||||
cf.cond_br %cond, ^bb1, ^bb2
|
||||
^bb1:
|
||||
br ^exit
|
||||
cf.br ^exit
|
||||
^bb2:
|
||||
br ^exit
|
||||
cf.br ^exit
|
||||
^exit:
|
||||
return
|
||||
}
|
||||
|
@ -49,14 +49,14 @@ func @func_condBranch(%cond : i1) {
|
|||
|
||||
// CHECK-LABEL: Testing : func_loop
|
||||
func @func_loop(%arg0 : i32, %arg1 : i32) {
|
||||
br ^loopHeader(%arg0 : i32)
|
||||
cf.br ^loopHeader(%arg0 : i32)
|
||||
^loopHeader(%counter : i32):
|
||||
%lessThan = arith.cmpi slt, %counter, %arg1 : i32
|
||||
cond_br %lessThan, ^loopBody, ^exit
|
||||
cf.cond_br %lessThan, ^loopBody, ^exit
|
||||
^loopBody:
|
||||
%const0 = arith.constant 1 : i32
|
||||
%inc = arith.addi %counter, %const0 : i32
|
||||
br ^loopHeader(%inc : i32)
|
||||
cf.br ^loopHeader(%inc : i32)
|
||||
^exit:
|
||||
return
|
||||
}
|
||||
|
@ -153,17 +153,17 @@ func @func_loop_nested_region(
|
|||
%arg2 : index,
|
||||
%arg3 : index,
|
||||
%arg4 : index) {
|
||||
br ^loopHeader(%arg0 : i32)
|
||||
cf.br ^loopHeader(%arg0 : i32)
|
||||
^loopHeader(%counter : i32):
|
||||
%lessThan = arith.cmpi slt, %counter, %arg1 : i32
|
||||
cond_br %lessThan, ^loopBody, ^exit
|
||||
cf.cond_br %lessThan, ^loopBody, ^exit
|
||||
^loopBody:
|
||||
%const0 = arith.constant 1 : i32
|
||||
%inc = arith.addi %counter, %const0 : i32
|
||||
scf.for %arg5 = %arg2 to %arg3 step %arg4 {
|
||||
scf.for %arg6 = %arg2 to %arg3 step %arg4 { }
|
||||
}
|
||||
br ^loopHeader(%inc : i32)
|
||||
cf.br ^loopHeader(%inc : i32)
|
||||
^exit:
|
||||
return
|
||||
}
|
||||
|
|
|
@ -19,7 +19,7 @@ func @func_simpleBranch(%arg0: i32, %arg1 : i32) -> i32 {
|
|||
// CHECK-NEXT: LiveOut: arg0@0 arg1@0
|
||||
// CHECK-NEXT: BeginLiveness
|
||||
// CHECK-NEXT: EndLiveness
|
||||
br ^exit
|
||||
cf.br ^exit
|
||||
^exit:
|
||||
// CHECK: Block: 1
|
||||
// CHECK-NEXT: LiveIn: arg0@0 arg1@0
|
||||
|
@ -42,17 +42,17 @@ func @func_condBranch(%cond : i1, %arg1: i32, %arg2 : i32) -> i32 {
|
|||
// CHECK-NEXT: LiveOut: arg1@0 arg2@0
|
||||
// CHECK-NEXT: BeginLiveness
|
||||
// CHECK-NEXT: EndLiveness
|
||||
cond_br %cond, ^bb1, ^bb2
|
||||
cf.cond_br %cond, ^bb1, ^bb2
|
||||
^bb1:
|
||||
// CHECK: Block: 1
|
||||
// CHECK-NEXT: LiveIn: arg1@0 arg2@0
|
||||
// CHECK-NEXT: LiveOut: arg1@0 arg2@0
|
||||
br ^exit
|
||||
cf.br ^exit
|
||||
^bb2:
|
||||
// CHECK: Block: 2
|
||||
// CHECK-NEXT: LiveIn: arg1@0 arg2@0
|
||||
// CHECK-NEXT: LiveOut: arg1@0 arg2@0
|
||||
br ^exit
|
||||
cf.br ^exit
|
||||
^exit:
|
||||
// CHECK: Block: 3
|
||||
// CHECK-NEXT: LiveIn: arg1@0 arg2@0
|
||||
|
@ -74,7 +74,7 @@ func @func_loop(%arg0 : i32, %arg1 : i32) -> i32 {
|
|||
// CHECK-NEXT: LiveIn:{{ *$}}
|
||||
// CHECK-NEXT: LiveOut: arg1@0
|
||||
%const0 = arith.constant 0 : i32
|
||||
br ^loopHeader(%const0, %arg0 : i32, i32)
|
||||
cf.br ^loopHeader(%const0, %arg0 : i32, i32)
|
||||
^loopHeader(%counter : i32, %i : i32):
|
||||
// CHECK: Block: 1
|
||||
// CHECK-NEXT: LiveIn: arg1@0
|
||||
|
@ -82,10 +82,10 @@ func @func_loop(%arg0 : i32, %arg1 : i32) -> i32 {
|
|||
// CHECK-NEXT: BeginLiveness
|
||||
// CHECK-NEXT: val_5
|
||||
// CHECK-NEXT: %2 = arith.cmpi
|
||||
// CHECK-NEXT: cond_br
|
||||
// CHECK-NEXT: cf.cond_br
|
||||
// CHECK-NEXT: EndLiveness
|
||||
%lessThan = arith.cmpi slt, %counter, %arg1 : i32
|
||||
cond_br %lessThan, ^loopBody(%i : i32), ^exit(%i : i32)
|
||||
cf.cond_br %lessThan, ^loopBody(%i : i32), ^exit(%i : i32)
|
||||
^loopBody(%val : i32):
|
||||
// CHECK: Block: 2
|
||||
// CHECK-NEXT: LiveIn: arg1@0 arg0@1
|
||||
|
@ -98,12 +98,12 @@ func @func_loop(%arg0 : i32, %arg1 : i32) -> i32 {
|
|||
// CHECK-NEXT: val_8
|
||||
// CHECK-NEXT: %4 = arith.addi
|
||||
// CHECK-NEXT: %5 = arith.addi
|
||||
// CHECK-NEXT: br
|
||||
// CHECK-NEXT: cf.br
|
||||
// CHECK: EndLiveness
|
||||
%const1 = arith.constant 1 : i32
|
||||
%inc = arith.addi %val, %const1 : i32
|
||||
%inc2 = arith.addi %counter, %const1 : i32
|
||||
br ^loopHeader(%inc, %inc2 : i32, i32)
|
||||
cf.br ^loopHeader(%inc, %inc2 : i32, i32)
|
||||
^exit(%sum : i32):
|
||||
// CHECK: Block: 3
|
||||
// CHECK-NEXT: LiveIn: arg1@0
|
||||
|
@ -147,14 +147,14 @@ func @func_ranges(%cond : i1, %arg1 : i32, %arg2 : i32, %arg3 : i32) -> i32 {
|
|||
// CHECK-NEXT: val_9
|
||||
// CHECK-NEXT: %4 = arith.muli
|
||||
// CHECK-NEXT: %5 = arith.addi
|
||||
// CHECK-NEXT: cond_br
|
||||
// CHECK-NEXT: cf.cond_br
|
||||
// CHECK-NEXT: %c
|
||||
// CHECK-NEXT: %6 = arith.muli
|
||||
// CHECK-NEXT: %7 = arith.muli
|
||||
// CHECK-NEXT: %8 = arith.addi
|
||||
// CHECK-NEXT: val_10
|
||||
// CHECK-NEXT: %5 = arith.addi
|
||||
// CHECK-NEXT: cond_br
|
||||
// CHECK-NEXT: cf.cond_br
|
||||
// CHECK-NEXT: %7
|
||||
// CHECK: EndLiveness
|
||||
%0 = arith.addi %arg1, %arg2 : i32
|
||||
|
@ -164,7 +164,7 @@ func @func_ranges(%cond : i1, %arg1 : i32, %arg2 : i32, %arg3 : i32) -> i32 {
|
|||
%3 = arith.muli %0, %1 : i32
|
||||
%4 = arith.muli %3, %2 : i32
|
||||
%5 = arith.addi %4, %const1 : i32
|
||||
cond_br %cond, ^bb1, ^bb2
|
||||
cf.cond_br %cond, ^bb1, ^bb2
|
||||
|
||||
^bb1:
|
||||
// CHECK: Block: 1
|
||||
|
@ -172,7 +172,7 @@ func @func_ranges(%cond : i1, %arg1 : i32, %arg2 : i32, %arg3 : i32) -> i32 {
|
|||
// CHECK-NEXT: LiveOut: arg2@0
|
||||
%const4 = arith.constant 4 : i32
|
||||
%6 = arith.muli %4, %const4 : i32
|
||||
br ^exit(%6 : i32)
|
||||
cf.br ^exit(%6 : i32)
|
||||
|
||||
^bb2:
|
||||
// CHECK: Block: 2
|
||||
|
@ -180,7 +180,7 @@ func @func_ranges(%cond : i1, %arg1 : i32, %arg2 : i32, %arg3 : i32) -> i32 {
|
|||
// CHECK-NEXT: LiveOut: arg2@0
|
||||
%7 = arith.muli %4, %5 : i32
|
||||
%8 = arith.addi %4, %arg2 : i32
|
||||
br ^exit(%8 : i32)
|
||||
cf.br ^exit(%8 : i32)
|
||||
|
||||
^exit(%sum : i32):
|
||||
// CHECK: Block: 3
|
||||
|
@ -284,7 +284,7 @@ func @nested_region3(
|
|||
// CHECK-NEXT: %0 = arith.addi
|
||||
// CHECK-NEXT: %1 = arith.addi
|
||||
// CHECK-NEXT: scf.for
|
||||
// CHECK: // br ^bb1
|
||||
// CHECK: // cf.br ^bb1
|
||||
// CHECK-NEXT: %2 = arith.addi
|
||||
// CHECK-NEXT: scf.for
|
||||
// CHECK: // %2 = arith.addi
|
||||
|
@ -301,7 +301,7 @@ func @nested_region3(
|
|||
%2 = arith.addi %0, %arg5 : i32
|
||||
memref.store %2, %buffer[] : memref<i32>
|
||||
}
|
||||
br ^exit
|
||||
cf.br ^exit
|
||||
|
||||
^exit:
|
||||
// CHECK: Block: 2
|
||||
|
|
|
@ -1531,10 +1531,10 @@ int registerOnlyStd() {
|
|||
fprintf(stderr, "@registration\n");
|
||||
// CHECK-LABEL: @registration
|
||||
|
||||
// CHECK: std.cond_br is_registered: 1
|
||||
fprintf(stderr, "std.cond_br is_registered: %d\n",
|
||||
// CHECK: cf.cond_br is_registered: 1
|
||||
fprintf(stderr, "cf.cond_br is_registered: %d\n",
|
||||
mlirContextIsRegisteredOperation(
|
||||
ctx, mlirStringRefCreateFromCString("std.cond_br")));
|
||||
ctx, mlirStringRefCreateFromCString("cf.cond_br")));
|
||||
|
||||
// CHECK: std.not_existing_op is_registered: 0
|
||||
fprintf(stderr, "std.not_existing_op is_registered: %d\n",
|
||||
|
|
|
@ -27,7 +27,7 @@ func @execute_no_async_args(%arg0: f32, %arg1: memref<1xf32>) {
|
|||
// CHECK: %[[IS_ERROR:.*]] = call @mlirAsyncRuntimeIsTokenError(%[[TOKEN]])
|
||||
// CHECK: %[[TRUE:.*]] = arith.constant true
|
||||
// CHECK: %[[NOT_ERROR:.*]] = arith.xori %[[IS_ERROR]], %[[TRUE]] : i1
|
||||
// CHECK: assert %[[NOT_ERROR]]
|
||||
// CHECK: cf.assert %[[NOT_ERROR]]
|
||||
// CHECK-NEXT: return
|
||||
async.await %token : !async.token
|
||||
return
|
||||
|
@ -90,7 +90,7 @@ func @nested_async_execute(%arg0: f32, %arg1: f32, %arg2: memref<1xf32>) {
|
|||
// CHECK: %[[IS_ERROR:.*]] = call @mlirAsyncRuntimeIsTokenError(%[[TOKEN]])
|
||||
// CHECK: %[[TRUE:.*]] = arith.constant true
|
||||
// CHECK: %[[NOT_ERROR:.*]] = arith.xori %[[IS_ERROR]], %[[TRUE]] : i1
|
||||
// CHECK: assert %[[NOT_ERROR]]
|
||||
// CHECK: cf.assert %[[NOT_ERROR]]
|
||||
async.await %token0 : !async.token
|
||||
return
|
||||
}
|
||||
|
|
|
@ -0,0 +1,41 @@
|
|||
// RUN: mlir-opt -split-input-file -convert-std-to-spirv -verify-diagnostics %s | FileCheck %s
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// cf.br, cf.cond_br
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
module attributes {
|
||||
spv.target_env = #spv.target_env<#spv.vce<v1.0, [], []>, {}>
|
||||
} {
|
||||
|
||||
// CHECK-LABEL: func @simple_loop
|
||||
func @simple_loop(index, index, index) {
|
||||
^bb0(%begin : index, %end : index, %step : index):
|
||||
// CHECK-NEXT: spv.Branch ^bb1
|
||||
cf.br ^bb1
|
||||
|
||||
// CHECK-NEXT: ^bb1: // pred: ^bb0
|
||||
// CHECK-NEXT: spv.Branch ^bb2({{.*}} : i32)
|
||||
^bb1: // pred: ^bb0
|
||||
cf.br ^bb2(%begin : index)
|
||||
|
||||
// CHECK: ^bb2({{.*}}: i32): // 2 preds: ^bb1, ^bb3
|
||||
// CHECK-NEXT: {{.*}} = spv.SLessThan {{.*}}, {{.*}} : i32
|
||||
// CHECK-NEXT: spv.BranchConditional {{.*}}, ^bb3, ^bb4
|
||||
^bb2(%0: index): // 2 preds: ^bb1, ^bb3
|
||||
%1 = arith.cmpi slt, %0, %end : index
|
||||
cf.cond_br %1, ^bb3, ^bb4
|
||||
|
||||
// CHECK: ^bb3: // pred: ^bb2
|
||||
// CHECK-NEXT: {{.*}} = spv.IAdd {{.*}}, {{.*}} : i32
|
||||
// CHECK-NEXT: spv.Branch ^bb2({{.*}} : i32)
|
||||
^bb3: // pred: ^bb2
|
||||
%2 = arith.addi %0, %step : index
|
||||
cf.br ^bb2(%2 : index)
|
||||
|
||||
// CHECK: ^bb4: // pred: ^bb2
|
||||
^bb4: // pred: ^bb2
|
||||
return
|
||||
}
|
||||
|
||||
}
|
|
@ -168,16 +168,16 @@ gpu.module @test_module {
|
|||
%c128 = arith.constant 128 : index
|
||||
%c32 = arith.constant 32 : index
|
||||
%0 = gpu.subgroup_mma_load_matrix %arg2[%c0, %c0] {leadDimension = 128 : index} : memref<128x128xf16> -> !gpu.mma_matrix<16x16xf16, "COp">
|
||||
br ^bb1(%c0, %0 : index, !gpu.mma_matrix<16x16xf16, "COp">)
|
||||
cf.br ^bb1(%c0, %0 : index, !gpu.mma_matrix<16x16xf16, "COp">)
|
||||
^bb1(%1: index, %2: !gpu.mma_matrix<16x16xf16, "COp">): // 2 preds: ^bb0, ^bb2
|
||||
%3 = arith.cmpi slt, %1, %c128 : index
|
||||
cond_br %3, ^bb2, ^bb3
|
||||
cf.cond_br %3, ^bb2, ^bb3
|
||||
^bb2: // pred: ^bb1
|
||||
%4 = gpu.subgroup_mma_load_matrix %arg0[%c0, %1] {leadDimension = 128 : index} : memref<128x128xf16> -> !gpu.mma_matrix<16x16xf16, "AOp">
|
||||
%5 = gpu.subgroup_mma_load_matrix %arg1[%1, %c0] {leadDimension = 128 : index} : memref<128x128xf16> -> !gpu.mma_matrix<16x16xf16, "BOp">
|
||||
%6 = gpu.subgroup_mma_compute %4, %5, %2 : !gpu.mma_matrix<16x16xf16, "AOp">, !gpu.mma_matrix<16x16xf16, "BOp"> -> !gpu.mma_matrix<16x16xf16, "COp">
|
||||
%7 = arith.addi %1, %c32 : index
|
||||
br ^bb1(%7, %6 : index, !gpu.mma_matrix<16x16xf16, "COp">)
|
||||
cf.br ^bb1(%7, %6 : index, !gpu.mma_matrix<16x16xf16, "COp">)
|
||||
^bb3: // pred: ^bb1
|
||||
gpu.subgroup_mma_store_matrix %2, %arg2[%c0, %c0] {leadDimension = 128 : index} : !gpu.mma_matrix<16x16xf16, "COp">, memref<128x128xf16>
|
||||
return
|
||||
|
|
|
@ -22,17 +22,17 @@ func @branch_loop() {
|
|||
// CHECK: omp.parallel
|
||||
omp.parallel {
|
||||
// CHECK-NEXT: llvm.br ^[[BB1:.*]](%{{[0-9]+}}, %{{[0-9]+}} : i64, i64
|
||||
br ^bb1(%start, %end : index, index)
|
||||
cf.br ^bb1(%start, %end : index, index)
|
||||
// CHECK-NEXT: ^[[BB1]](%[[ARG1:[0-9]+]]: i64, %[[ARG2:[0-9]+]]: i64):{{.*}}
|
||||
^bb1(%0: index, %1: index):
|
||||
// CHECK-NEXT: %[[CMP:[0-9]+]] = llvm.icmp "slt" %[[ARG1]], %[[ARG2]] : i64
|
||||
%2 = arith.cmpi slt, %0, %1 : index
|
||||
// CHECK-NEXT: llvm.cond_br %[[CMP]], ^[[BB2:.*]](%{{[0-9]+}}, %{{[0-9]+}} : i64, i64), ^[[BB3:.*]]
|
||||
cond_br %2, ^bb2(%end, %end : index, index), ^bb3
|
||||
cf.cond_br %2, ^bb2(%end, %end : index, index), ^bb3
|
||||
// CHECK-NEXT: ^[[BB2]](%[[ARG3:[0-9]+]]: i64, %[[ARG4:[0-9]+]]: i64):
|
||||
^bb2(%3: index, %4: index):
|
||||
// CHECK-NEXT: llvm.br ^[[BB1]](%[[ARG3]], %[[ARG4]] : i64, i64)
|
||||
br ^bb1(%3, %4 : index, index)
|
||||
cf.br ^bb1(%3, %4 : index, index)
|
||||
// CHECK-NEXT: ^[[BB3]]:
|
||||
^bb3:
|
||||
omp.flush
|
||||
|
|
|
@ -1,14 +1,14 @@
|
|||
// RUN: mlir-opt -allow-unregistered-dialect -convert-scf-to-std %s | FileCheck %s
|
||||
// RUN: mlir-opt -allow-unregistered-dialect -convert-scf-to-cf %s | FileCheck %s
|
||||
|
||||
// CHECK-LABEL: func @simple_std_for_loop(%{{.*}}: index, %{{.*}}: index, %{{.*}}: index) {
|
||||
// CHECK-NEXT: br ^bb1(%{{.*}} : index)
|
||||
// CHECK-NEXT: cf.br ^bb1(%{{.*}} : index)
|
||||
// CHECK-NEXT: ^bb1(%{{.*}}: index): // 2 preds: ^bb0, ^bb2
|
||||
// CHECK-NEXT: %{{.*}} = arith.cmpi slt, %{{.*}}, %{{.*}} : index
|
||||
// CHECK-NEXT: cond_br %{{.*}}, ^bb2, ^bb3
|
||||
// CHECK-NEXT: cf.cond_br %{{.*}}, ^bb2, ^bb3
|
||||
// CHECK-NEXT: ^bb2: // pred: ^bb1
|
||||
// CHECK-NEXT: %{{.*}} = arith.constant 1 : index
|
||||
// CHECK-NEXT: %[[iv:.*]] = arith.addi %{{.*}}, %{{.*}} : index
|
||||
// CHECK-NEXT: br ^bb1(%[[iv]] : index)
|
||||
// CHECK-NEXT: cf.br ^bb1(%[[iv]] : index)
|
||||
// CHECK-NEXT: ^bb3: // pred: ^bb1
|
||||
// CHECK-NEXT: return
|
||||
func @simple_std_for_loop(%arg0 : index, %arg1 : index, %arg2 : index) {
|
||||
|
@ -19,23 +19,23 @@ func @simple_std_for_loop(%arg0 : index, %arg1 : index, %arg2 : index) {
|
|||
}
|
||||
|
||||
// CHECK-LABEL: func @simple_std_2_for_loops(%{{.*}}: index, %{{.*}}: index, %{{.*}}: index) {
|
||||
// CHECK-NEXT: br ^bb1(%{{.*}} : index)
|
||||
// CHECK-NEXT: cf.br ^bb1(%{{.*}} : index)
|
||||
// CHECK-NEXT: ^bb1(%[[ub0:.*]]: index): // 2 preds: ^bb0, ^bb5
|
||||
// CHECK-NEXT: %[[cond0:.*]] = arith.cmpi slt, %[[ub0]], %{{.*}} : index
|
||||
// CHECK-NEXT: cond_br %[[cond0]], ^bb2, ^bb6
|
||||
// CHECK-NEXT: cf.cond_br %[[cond0]], ^bb2, ^bb6
|
||||
// CHECK-NEXT: ^bb2: // pred: ^bb1
|
||||
// CHECK-NEXT: %{{.*}} = arith.constant 1 : index
|
||||
// CHECK-NEXT: br ^bb3(%{{.*}} : index)
|
||||
// CHECK-NEXT: cf.br ^bb3(%{{.*}} : index)
|
||||
// CHECK-NEXT: ^bb3(%[[ub1:.*]]: index): // 2 preds: ^bb2, ^bb4
|
||||
// CHECK-NEXT: %[[cond1:.*]] = arith.cmpi slt, %{{.*}}, %{{.*}} : index
|
||||
// CHECK-NEXT: cond_br %[[cond1]], ^bb4, ^bb5
|
||||
// CHECK-NEXT: cf.cond_br %[[cond1]], ^bb4, ^bb5
|
||||
// CHECK-NEXT: ^bb4: // pred: ^bb3
|
||||
// CHECK-NEXT: %{{.*}} = arith.constant 1 : index
|
||||
// CHECK-NEXT: %[[iv1:.*]] = arith.addi %{{.*}}, %{{.*}} : index
|
||||
// CHECK-NEXT: br ^bb3(%[[iv1]] : index)
|
||||
// CHECK-NEXT: cf.br ^bb3(%[[iv1]] : index)
|
||||
// CHECK-NEXT: ^bb5: // pred: ^bb3
|
||||
// CHECK-NEXT: %[[iv0:.*]] = arith.addi %{{.*}}, %{{.*}} : index
|
||||
// CHECK-NEXT: br ^bb1(%[[iv0]] : index)
|
||||
// CHECK-NEXT: cf.br ^bb1(%[[iv0]] : index)
|
||||
// CHECK-NEXT: ^bb6: // pred: ^bb1
|
||||
// CHECK-NEXT: return
|
||||
func @simple_std_2_for_loops(%arg0 : index, %arg1 : index, %arg2 : index) {
|
||||
|
@ -49,10 +49,10 @@ func @simple_std_2_for_loops(%arg0 : index, %arg1 : index, %arg2 : index) {
|
|||
}
|
||||
|
||||
// CHECK-LABEL: func @simple_std_if(%{{.*}}: i1) {
|
||||
// CHECK-NEXT: cond_br %{{.*}}, ^bb1, ^bb2
|
||||
// CHECK-NEXT: cf.cond_br %{{.*}}, ^bb1, ^bb2
|
||||
// CHECK-NEXT: ^bb1: // pred: ^bb0
|
||||
// CHECK-NEXT: %{{.*}} = arith.constant 1 : index
|
||||
// CHECK-NEXT: br ^bb2
|
||||
// CHECK-NEXT: cf.br ^bb2
|
||||
// CHECK-NEXT: ^bb2: // 2 preds: ^bb0, ^bb1
|
||||
// CHECK-NEXT: return
|
||||
func @simple_std_if(%arg0: i1) {
|
||||
|
@ -63,13 +63,13 @@ func @simple_std_if(%arg0: i1) {
|
|||
}
|
||||
|
||||
// CHECK-LABEL: func @simple_std_if_else(%{{.*}}: i1) {
|
||||
// CHECK-NEXT: cond_br %{{.*}}, ^bb1, ^bb2
|
||||
// CHECK-NEXT: cf.cond_br %{{.*}}, ^bb1, ^bb2
|
||||
// CHECK-NEXT: ^bb1: // pred: ^bb0
|
||||
// CHECK-NEXT: %{{.*}} = arith.constant 1 : index
|
||||
// CHECK-NEXT: br ^bb3
|
||||
// CHECK-NEXT: cf.br ^bb3
|
||||
// CHECK-NEXT: ^bb2: // pred: ^bb0
|
||||
// CHECK-NEXT: %{{.*}} = arith.constant 1 : index
|
||||
// CHECK-NEXT: br ^bb3
|
||||
// CHECK-NEXT: cf.br ^bb3
|
||||
// CHECK-NEXT: ^bb3: // 2 preds: ^bb1, ^bb2
|
||||
// CHECK-NEXT: return
|
||||
func @simple_std_if_else(%arg0: i1) {
|
||||
|
@ -82,18 +82,18 @@ func @simple_std_if_else(%arg0: i1) {
|
|||
}
|
||||
|
||||
// CHECK-LABEL: func @simple_std_2_ifs(%{{.*}}: i1) {
|
||||
// CHECK-NEXT: cond_br %{{.*}}, ^bb1, ^bb5
|
||||
// CHECK-NEXT: cf.cond_br %{{.*}}, ^bb1, ^bb5
|
||||
// CHECK-NEXT: ^bb1: // pred: ^bb0
|
||||
// CHECK-NEXT: %{{.*}} = arith.constant 1 : index
|
||||
// CHECK-NEXT: cond_br %{{.*}}, ^bb2, ^bb3
|
||||
// CHECK-NEXT: cf.cond_br %{{.*}}, ^bb2, ^bb3
|
||||
// CHECK-NEXT: ^bb2: // pred: ^bb1
|
||||
// CHECK-NEXT: %{{.*}} = arith.constant 1 : index
|
||||
// CHECK-NEXT: br ^bb4
|
||||
// CHECK-NEXT: cf.br ^bb4
|
||||
// CHECK-NEXT: ^bb3: // pred: ^bb1
|
||||
// CHECK-NEXT: %{{.*}} = arith.constant 1 : index
|
||||
// CHECK-NEXT: br ^bb4
|
||||
// CHECK-NEXT: cf.br ^bb4
|
||||
// CHECK-NEXT: ^bb4: // 2 preds: ^bb2, ^bb3
|
||||
// CHECK-NEXT: br ^bb5
|
||||
// CHECK-NEXT: cf.br ^bb5
|
||||
// CHECK-NEXT: ^bb5: // 2 preds: ^bb0, ^bb4
|
||||
// CHECK-NEXT: return
|
||||
func @simple_std_2_ifs(%arg0: i1) {
|
||||
|
@ -109,27 +109,27 @@ func @simple_std_2_ifs(%arg0: i1) {
|
|||
}
|
||||
|
||||
// CHECK-LABEL: func @simple_std_for_loop_with_2_ifs(%{{.*}}: index, %{{.*}}: index, %{{.*}}: index, %{{.*}}: i1) {
|
||||
// CHECK-NEXT: br ^bb1(%{{.*}} : index)
|
||||
// CHECK-NEXT: cf.br ^bb1(%{{.*}} : index)
|
||||
// CHECK-NEXT: ^bb1(%{{.*}}: index): // 2 preds: ^bb0, ^bb7
|
||||
// CHECK-NEXT: %{{.*}} = arith.cmpi slt, %{{.*}}, %{{.*}} : index
|
||||
// CHECK-NEXT: cond_br %{{.*}}, ^bb2, ^bb8
|
||||
// CHECK-NEXT: cf.cond_br %{{.*}}, ^bb2, ^bb8
|
||||
// CHECK-NEXT: ^bb2: // pred: ^bb1
|
||||
// CHECK-NEXT: %{{.*}} = arith.constant 1 : index
|
||||
// CHECK-NEXT: cond_br %{{.*}}, ^bb3, ^bb7
|
||||
// CHECK-NEXT: cf.cond_br %{{.*}}, ^bb3, ^bb7
|
||||
// CHECK-NEXT: ^bb3: // pred: ^bb2
|
||||
// CHECK-NEXT: %{{.*}} = arith.constant 1 : index
|
||||
// CHECK-NEXT: cond_br %{{.*}}, ^bb4, ^bb5
|
||||
// CHECK-NEXT: cf.cond_br %{{.*}}, ^bb4, ^bb5
|
||||
// CHECK-NEXT: ^bb4: // pred: ^bb3
|
||||
// CHECK-NEXT: %{{.*}} = arith.constant 1 : index
|
||||
// CHECK-NEXT: br ^bb6
|
||||
// CHECK-NEXT: cf.br ^bb6
|
||||
// CHECK-NEXT: ^bb5: // pred: ^bb3
|
||||
// CHECK-NEXT: %{{.*}} = arith.constant 1 : index
|
||||
// CHECK-NEXT: br ^bb6
|
||||
// CHECK-NEXT: cf.br ^bb6
|
||||
// CHECK-NEXT: ^bb6: // 2 preds: ^bb4, ^bb5
|
||||
// CHECK-NEXT: br ^bb7
|
||||
// CHECK-NEXT: cf.br ^bb7
|
||||
// CHECK-NEXT: ^bb7: // 2 preds: ^bb2, ^bb6
|
||||
// CHECK-NEXT: %[[iv0:.*]] = arith.addi %{{.*}}, %{{.*}} : index
|
||||
// CHECK-NEXT: br ^bb1(%[[iv0]] : index)
|
||||
// CHECK-NEXT: cf.br ^bb1(%[[iv0]] : index)
|
||||
// CHECK-NEXT: ^bb8: // pred: ^bb1
|
||||
// CHECK-NEXT: return
|
||||
// CHECK-NEXT: }
|
||||
|
@ -150,12 +150,12 @@ func @simple_std_for_loop_with_2_ifs(%arg0 : index, %arg1 : index, %arg2 : index
|
|||
|
||||
// CHECK-LABEL: func @simple_if_yield
|
||||
func @simple_if_yield(%arg0: i1) -> (i1, i1) {
|
||||
// CHECK: cond_br %{{.*}}, ^[[then:.*]], ^[[else:.*]]
|
||||
// CHECK: cf.cond_br %{{.*}}, ^[[then:.*]], ^[[else:.*]]
|
||||
%0:2 = scf.if %arg0 -> (i1, i1) {
|
||||
// CHECK: ^[[then]]:
|
||||
// CHECK: %[[v0:.*]] = arith.constant false
|
||||
// CHECK: %[[v1:.*]] = arith.constant true
|
||||
// CHECK: br ^[[dom:.*]](%[[v0]], %[[v1]] : i1, i1)
|
||||
// CHECK: cf.br ^[[dom:.*]](%[[v0]], %[[v1]] : i1, i1)
|
||||
%c0 = arith.constant false
|
||||
%c1 = arith.constant true
|
||||
scf.yield %c0, %c1 : i1, i1
|
||||
|
@ -163,13 +163,13 @@ func @simple_if_yield(%arg0: i1) -> (i1, i1) {
|
|||
// CHECK: ^[[else]]:
|
||||
// CHECK: %[[v2:.*]] = arith.constant false
|
||||
// CHECK: %[[v3:.*]] = arith.constant true
|
||||
// CHECK: br ^[[dom]](%[[v3]], %[[v2]] : i1, i1)
|
||||
// CHECK: cf.br ^[[dom]](%[[v3]], %[[v2]] : i1, i1)
|
||||
%c0 = arith.constant false
|
||||
%c1 = arith.constant true
|
||||
scf.yield %c1, %c0 : i1, i1
|
||||
}
|
||||
// CHECK: ^[[dom]](%[[arg1:.*]]: i1, %[[arg2:.*]]: i1):
|
||||
// CHECK: br ^[[cont:.*]]
|
||||
// CHECK: cf.br ^[[cont:.*]]
|
||||
// CHECK: ^[[cont]]:
|
||||
// CHECK: return %[[arg1]], %[[arg2]]
|
||||
return %0#0, %0#1 : i1, i1
|
||||
|
@ -177,49 +177,49 @@ func @simple_if_yield(%arg0: i1) -> (i1, i1) {
|
|||
|
||||
// CHECK-LABEL: func @nested_if_yield
|
||||
func @nested_if_yield(%arg0: i1) -> (index) {
|
||||
// CHECK: cond_br %{{.*}}, ^[[first_then:.*]], ^[[first_else:.*]]
|
||||
// CHECK: cf.cond_br %{{.*}}, ^[[first_then:.*]], ^[[first_else:.*]]
|
||||
%0 = scf.if %arg0 -> i1 {
|
||||
// CHECK: ^[[first_then]]:
|
||||
%1 = arith.constant true
|
||||
// CHECK: br ^[[first_dom:.*]]({{.*}})
|
||||
// CHECK: cf.br ^[[first_dom:.*]]({{.*}})
|
||||
scf.yield %1 : i1
|
||||
} else {
|
||||
// CHECK: ^[[first_else]]:
|
||||
%2 = arith.constant false
|
||||
// CHECK: br ^[[first_dom]]({{.*}})
|
||||
// CHECK: cf.br ^[[first_dom]]({{.*}})
|
||||
scf.yield %2 : i1
|
||||
}
|
||||
// CHECK: ^[[first_dom]](%[[arg1:.*]]: i1):
|
||||
// CHECK: br ^[[first_cont:.*]]
|
||||
// CHECK: cf.br ^[[first_cont:.*]]
|
||||
// CHECK: ^[[first_cont]]:
|
||||
// CHECK: cond_br %[[arg1]], ^[[second_outer_then:.*]], ^[[second_outer_else:.*]]
|
||||
// CHECK: cf.cond_br %[[arg1]], ^[[second_outer_then:.*]], ^[[second_outer_else:.*]]
|
||||
%1 = scf.if %0 -> index {
|
||||
// CHECK: ^[[second_outer_then]]:
|
||||
// CHECK: cond_br %arg0, ^[[second_inner_then:.*]], ^[[second_inner_else:.*]]
|
||||
// CHECK: cf.cond_br %arg0, ^[[second_inner_then:.*]], ^[[second_inner_else:.*]]
|
||||
%3 = scf.if %arg0 -> index {
|
||||
// CHECK: ^[[second_inner_then]]:
|
||||
%4 = arith.constant 40 : index
|
||||
// CHECK: br ^[[second_inner_dom:.*]]({{.*}})
|
||||
// CHECK: cf.br ^[[second_inner_dom:.*]]({{.*}})
|
||||
scf.yield %4 : index
|
||||
} else {
|
||||
// CHECK: ^[[second_inner_else]]:
|
||||
%5 = arith.constant 41 : index
|
||||
// CHECK: br ^[[second_inner_dom]]({{.*}})
|
||||
// CHECK: cf.br ^[[second_inner_dom]]({{.*}})
|
||||
scf.yield %5 : index
|
||||
}
|
||||
// CHECK: ^[[second_inner_dom]](%[[arg2:.*]]: index):
|
||||
// CHECK: br ^[[second_inner_cont:.*]]
|
||||
// CHECK: cf.br ^[[second_inner_cont:.*]]
|
||||
// CHECK: ^[[second_inner_cont]]:
|
||||
// CHECK: br ^[[second_outer_dom:.*]]({{.*}})
|
||||
// CHECK: cf.br ^[[second_outer_dom:.*]]({{.*}})
|
||||
scf.yield %3 : index
|
||||
} else {
|
||||
// CHECK: ^[[second_outer_else]]:
|
||||
%6 = arith.constant 42 : index
|
||||
// CHECK: br ^[[second_outer_dom]]({{.*}}
|
||||
// CHECK: cf.br ^[[second_outer_dom]]({{.*}})
|
||||
scf.yield %6 : index
|
||||
}
|
||||
// CHECK: ^[[second_outer_dom]](%[[arg3:.*]]: index):
|
||||
// CHECK: br ^[[second_outer_cont:.*]]
|
||||
// CHECK: cf.br ^[[second_outer_cont:.*]]
|
||||
// CHECK: ^[[second_outer_cont]]:
|
||||
// CHECK: return %[[arg3]] : index
|
||||
return %1 : index
|
||||
|
@ -228,22 +228,22 @@ func @nested_if_yield(%arg0: i1) -> (index) {
|
|||
// CHECK-LABEL: func @parallel_loop(
|
||||
// CHECK-SAME: [[VAL_0:%.*]]: index, [[VAL_1:%.*]]: index, [[VAL_2:%.*]]: index, [[VAL_3:%.*]]: index, [[VAL_4:%.*]]: index) {
|
||||
// CHECK: [[VAL_5:%.*]] = arith.constant 1 : index
|
||||
// CHECK: br ^bb1([[VAL_0]] : index)
|
||||
// CHECK: cf.br ^bb1([[VAL_0]] : index)
|
||||
// CHECK: ^bb1([[VAL_6:%.*]]: index):
|
||||
// CHECK: [[VAL_7:%.*]] = arith.cmpi slt, [[VAL_6]], [[VAL_2]] : index
|
||||
// CHECK: cond_br [[VAL_7]], ^bb2, ^bb6
|
||||
// CHECK: cf.cond_br [[VAL_7]], ^bb2, ^bb6
|
||||
// CHECK: ^bb2:
|
||||
// CHECK: br ^bb3([[VAL_1]] : index)
|
||||
// CHECK: cf.br ^bb3([[VAL_1]] : index)
|
||||
// CHECK: ^bb3([[VAL_8:%.*]]: index):
|
||||
// CHECK: [[VAL_9:%.*]] = arith.cmpi slt, [[VAL_8]], [[VAL_3]] : index
|
||||
// CHECK: cond_br [[VAL_9]], ^bb4, ^bb5
|
||||
// CHECK: cf.cond_br [[VAL_9]], ^bb4, ^bb5
|
||||
// CHECK: ^bb4:
|
||||
// CHECK: [[VAL_10:%.*]] = arith.constant 1 : index
|
||||
// CHECK: [[VAL_11:%.*]] = arith.addi [[VAL_8]], [[VAL_5]] : index
|
||||
// CHECK: br ^bb3([[VAL_11]] : index)
|
||||
// CHECK: cf.br ^bb3([[VAL_11]] : index)
|
||||
// CHECK: ^bb5:
|
||||
// CHECK: [[VAL_12:%.*]] = arith.addi [[VAL_6]], [[VAL_4]] : index
|
||||
// CHECK: br ^bb1([[VAL_12]] : index)
|
||||
// CHECK: cf.br ^bb1([[VAL_12]] : index)
|
||||
// CHECK: ^bb6:
|
||||
// CHECK: return
|
||||
// CHECK: }
|
||||
|
@ -262,16 +262,16 @@ func @parallel_loop(%arg0 : index, %arg1 : index, %arg2 : index,
|
|||
// CHECK-SAME: (%[[LB:.*]]: index, %[[UB:.*]]: index, %[[STEP:.*]]: index)
|
||||
// CHECK: %[[INIT0:.*]] = arith.constant 0
|
||||
// CHECK: %[[INIT1:.*]] = arith.constant 1
|
||||
// CHECK: br ^[[COND:.*]](%[[LB]], %[[INIT0]], %[[INIT1]] : index, f32, f32)
|
||||
// CHECK: cf.br ^[[COND:.*]](%[[LB]], %[[INIT0]], %[[INIT1]] : index, f32, f32)
|
||||
//
|
||||
// CHECK: ^[[COND]](%[[ITER:.*]]: index, %[[ITER_ARG0:.*]]: f32, %[[ITER_ARG1:.*]]: f32):
|
||||
// CHECK: %[[CMP:.*]] = arith.cmpi slt, %[[ITER]], %[[UB]] : index
|
||||
// CHECK: cond_br %[[CMP]], ^[[BODY:.*]], ^[[CONTINUE:.*]]
|
||||
// CHECK: cf.cond_br %[[CMP]], ^[[BODY:.*]], ^[[CONTINUE:.*]]
|
||||
//
|
||||
// CHECK: ^[[BODY]]:
|
||||
// CHECK: %[[SUM:.*]] = arith.addf %[[ITER_ARG0]], %[[ITER_ARG1]] : f32
|
||||
// CHECK: %[[STEPPED:.*]] = arith.addi %[[ITER]], %[[STEP]] : index
|
||||
// CHECK: br ^[[COND]](%[[STEPPED]], %[[SUM]], %[[SUM]] : index, f32, f32)
|
||||
// CHECK: cf.br ^[[COND]](%[[STEPPED]], %[[SUM]], %[[SUM]] : index, f32, f32)
|
||||
//
|
||||
// CHECK: ^[[CONTINUE]]:
|
||||
// CHECK: return %[[ITER_ARG0]], %[[ITER_ARG1]] : f32, f32
|
||||
|
@ -288,18 +288,18 @@ func @for_yield(%arg0 : index, %arg1 : index, %arg2 : index) -> (f32, f32) {
|
|||
// CHECK-LABEL: @nested_for_yield
|
||||
// CHECK-SAME: (%[[LB:.*]]: index, %[[UB:.*]]: index, %[[STEP:.*]]: index)
|
||||
// CHECK: %[[INIT:.*]] = arith.constant
|
||||
// CHECK: br ^[[COND_OUT:.*]](%[[LB]], %[[INIT]] : index, f32)
|
||||
// CHECK: cf.br ^[[COND_OUT:.*]](%[[LB]], %[[INIT]] : index, f32)
|
||||
// CHECK: ^[[COND_OUT]](%[[ITER_OUT:.*]]: index, %[[ARG_OUT:.*]]: f32):
|
||||
// CHECK: cond_br %{{.*}}, ^[[BODY_OUT:.*]], ^[[CONT_OUT:.*]]
|
||||
// CHECK: cf.cond_br %{{.*}}, ^[[BODY_OUT:.*]], ^[[CONT_OUT:.*]]
|
||||
// CHECK: ^[[BODY_OUT]]:
|
||||
// CHECK: br ^[[COND_IN:.*]](%[[LB]], %[[ARG_OUT]] : index, f32)
|
||||
// CHECK: cf.br ^[[COND_IN:.*]](%[[LB]], %[[ARG_OUT]] : index, f32)
|
||||
// CHECK: ^[[COND_IN]](%[[ITER_IN:.*]]: index, %[[ARG_IN:.*]]: f32):
|
||||
// CHECK: cond_br %{{.*}}, ^[[BODY_IN:.*]], ^[[CONT_IN:.*]]
|
||||
// CHECK: cf.cond_br %{{.*}}, ^[[BODY_IN:.*]], ^[[CONT_IN:.*]]
|
||||
// CHECK: ^[[BODY_IN]]
|
||||
// CHECK: %[[RES:.*]] = arith.addf
|
||||
// CHECK: br ^[[COND_IN]](%{{.*}}, %[[RES]] : index, f32)
|
||||
// CHECK: cf.br ^[[COND_IN]](%{{.*}}, %[[RES]] : index, f32)
|
||||
// CHECK: ^[[CONT_IN]]:
|
||||
// CHECK: br ^[[COND_OUT]](%{{.*}}, %[[ARG_IN]] : index, f32)
|
||||
// CHECK: cf.br ^[[COND_OUT]](%{{.*}}, %[[ARG_IN]] : index, f32)
|
||||
// CHECK: ^[[CONT_OUT]]:
|
||||
// CHECK: return %[[ARG_OUT]] : f32
|
||||
func @nested_for_yield(%arg0 : index, %arg1 : index, %arg2 : index) -> f32 {
|
||||
|
@ -325,13 +325,13 @@ func @simple_parallel_reduce_loop(%arg0: index, %arg1: index,
|
|||
// passed across as a block argument.
|
||||
|
||||
// Branch to the condition block passing in the initial reduction value.
|
||||
// CHECK: br ^[[COND:.*]](%[[LB]], %[[INIT]]
|
||||
// CHECK: cf.br ^[[COND:.*]](%[[LB]], %[[INIT]]
|
||||
|
||||
// Condition branch takes as arguments the current value of the iteration
|
||||
// variable and the current partially reduced value.
|
||||
// CHECK: ^[[COND]](%[[ITER:.*]]: index, %[[ITER_ARG:.*]]: f32
|
||||
// CHECK: %[[COMP:.*]] = arith.cmpi slt, %[[ITER]], %[[UB]]
|
||||
// CHECK: cond_br %[[COMP]], ^[[BODY:.*]], ^[[CONTINUE:.*]]
|
||||
// CHECK: cf.cond_br %[[COMP]], ^[[BODY:.*]], ^[[CONTINUE:.*]]
|
||||
|
||||
// Bodies of scf.reduce operations are folded into the main loop body. The
|
||||
// result of this partial reduction is passed as argument to the condition
|
||||
|
@ -340,7 +340,7 @@ func @simple_parallel_reduce_loop(%arg0: index, %arg1: index,
|
|||
// CHECK: %[[CST:.*]] = arith.constant 4.2
|
||||
// CHECK: %[[PROD:.*]] = arith.mulf %[[ITER_ARG]], %[[CST]]
|
||||
// CHECK: %[[INCR:.*]] = arith.addi %[[ITER]], %[[STEP]]
|
||||
// CHECK: br ^[[COND]](%[[INCR]], %[[PROD]]
|
||||
// CHECK: cf.br ^[[COND]](%[[INCR]], %[[PROD]]
|
||||
|
||||
// The continuation block has access to the (last value of) reduction.
|
||||
// CHECK: ^[[CONTINUE]]:
|
||||
|
@ -363,19 +363,19 @@ func @parallel_reduce_loop(%arg0 : index, %arg1 : index, %arg2 : index,
|
|||
// Multiple reduction blocks should be folded in the same body, and the
|
||||
// reduction value must be forwarded through block structures.
|
||||
// CHECK: %[[INIT2:.*]] = arith.constant 42
|
||||
// CHECK: br ^[[COND_OUT:.*]](%{{.*}}, %[[INIT1]], %[[INIT2]]
|
||||
// CHECK: cf.br ^[[COND_OUT:.*]](%{{.*}}, %[[INIT1]], %[[INIT2]]
|
||||
// CHECK: ^[[COND_OUT]](%{{.*}}: index, %[[ITER_ARG1_OUT:.*]]: f32, %[[ITER_ARG2_OUT:.*]]: i64
|
||||
// CHECK: cond_br %{{.*}}, ^[[BODY_OUT:.*]], ^[[CONT_OUT:.*]]
|
||||
// CHECK: cf.cond_br %{{.*}}, ^[[BODY_OUT:.*]], ^[[CONT_OUT:.*]]
|
||||
// CHECK: ^[[BODY_OUT]]:
|
||||
// CHECK: br ^[[COND_IN:.*]](%{{.*}}, %[[ITER_ARG1_OUT]], %[[ITER_ARG2_OUT]]
|
||||
// CHECK: cf.br ^[[COND_IN:.*]](%{{.*}}, %[[ITER_ARG1_OUT]], %[[ITER_ARG2_OUT]]
|
||||
// CHECK: ^[[COND_IN]](%{{.*}}: index, %[[ITER_ARG1_IN:.*]]: f32, %[[ITER_ARG2_IN:.*]]: i64
|
||||
// CHECK: cond_br %{{.*}}, ^[[BODY_IN:.*]], ^[[CONT_IN:.*]]
|
||||
// CHECK: cf.cond_br %{{.*}}, ^[[BODY_IN:.*]], ^[[CONT_IN:.*]]
|
||||
// CHECK: ^[[BODY_IN]]:
|
||||
// CHECK: %[[REDUCE1:.*]] = arith.addf %[[ITER_ARG1_IN]], %{{.*}}
|
||||
// CHECK: %[[REDUCE2:.*]] = arith.ori %[[ITER_ARG2_IN]], %{{.*}}
|
||||
// CHECK: br ^[[COND_IN]](%{{.*}}, %[[REDUCE1]], %[[REDUCE2]]
|
||||
// CHECK: cf.br ^[[COND_IN]](%{{.*}}, %[[REDUCE1]], %[[REDUCE2]]
|
||||
// CHECK: ^[[CONT_IN]]:
|
||||
// CHECK: br ^[[COND_OUT]](%{{.*}}, %[[ITER_ARG1_IN]], %[[ITER_ARG2_IN]]
|
||||
// CHECK: cf.br ^[[COND_OUT]](%{{.*}}, %[[ITER_ARG1_IN]], %[[ITER_ARG2_IN]]
|
||||
// CHECK: ^[[CONT_OUT]]:
|
||||
// CHECK: return %[[ITER_ARG1_OUT]], %[[ITER_ARG2_OUT]]
|
||||
%step = arith.constant 1 : index
|
||||
|
@ -416,17 +416,17 @@ func @unknown_op_inside_loop(%arg0: index, %arg1: index, %arg2: index) {
|
|||
// CHECK-LABEL: @minimal_while
|
||||
func @minimal_while() {
|
||||
// CHECK: %[[COND:.*]] = "test.make_condition"() : () -> i1
|
||||
// CHECK: br ^[[BEFORE:.*]]
|
||||
// CHECK: cf.br ^[[BEFORE:.*]]
|
||||
%0 = "test.make_condition"() : () -> i1
|
||||
scf.while : () -> () {
|
||||
// CHECK: ^[[BEFORE]]:
|
||||
// CHECK: cond_br %[[COND]], ^[[AFTER:.*]], ^[[CONT:.*]]
|
||||
// CHECK: cf.cond_br %[[COND]], ^[[AFTER:.*]], ^[[CONT:.*]]
|
||||
scf.condition(%0)
|
||||
} do {
|
||||
// CHECK: ^[[AFTER]]:
|
||||
// CHECK: "test.some_payload"() : () -> ()
|
||||
"test.some_payload"() : () -> ()
|
||||
// CHECK: br ^[[BEFORE]]
|
||||
// CHECK: cf.br ^[[BEFORE]]
|
||||
scf.yield
|
||||
}
|
||||
// CHECK: ^[[CONT]]:
|
||||
|
@ -436,16 +436,16 @@ func @minimal_while() {
|
|||
|
||||
// CHECK-LABEL: @do_while
|
||||
func @do_while(%arg0: f32) {
|
||||
// CHECK: br ^[[BEFORE:.*]]({{.*}}: f32)
|
||||
// CHECK: cf.br ^[[BEFORE:.*]]({{.*}}: f32)
|
||||
scf.while (%arg1 = %arg0) : (f32) -> (f32) {
|
||||
// CHECK: ^[[BEFORE]](%[[VAL:.*]]: f32):
|
||||
// CHECK: %[[COND:.*]] = "test.make_condition"() : () -> i1
|
||||
%0 = "test.make_condition"() : () -> i1
|
||||
// CHECK: cond_br %[[COND]], ^[[BEFORE]](%[[VAL]] : f32), ^[[CONT:.*]]
|
||||
// CHECK: cf.cond_br %[[COND]], ^[[BEFORE]](%[[VAL]] : f32), ^[[CONT:.*]]
|
||||
scf.condition(%0) %arg1 : f32
|
||||
} do {
|
||||
^bb0(%arg2: f32):
|
||||
// CHECK-NOT: br ^[[BEFORE]]
|
||||
// CHECK-NOT: cf.br ^[[BEFORE]]
|
||||
scf.yield %arg2 : f32
|
||||
}
|
||||
// CHECK: ^[[CONT]]:
|
||||
|
@ -460,21 +460,21 @@ func @while_values(%arg0: i32, %arg1: f32) {
|
|||
%0 = "test.make_condition"() : () -> i1
|
||||
%c0_i32 = arith.constant 0 : i32
|
||||
%cst = arith.constant 0.000000e+00 : f32
|
||||
// CHECK: br ^[[BEFORE:.*]](%[[ARG0]], %[[ARG1]] : i32, f32)
|
||||
// CHECK: cf.br ^[[BEFORE:.*]](%[[ARG0]], %[[ARG1]] : i32, f32)
|
||||
%1:2 = scf.while (%arg2 = %arg0, %arg3 = %arg1) : (i32, f32) -> (i64, f64) {
|
||||
// CHECK: ^bb1(%[[ARG2:.*]]: i32, %[[ARG3:.*]]: f32):
|
||||
// CHECK: %[[VAL1:.*]] = arith.extui %[[ARG0]] : i32 to i64
|
||||
%2 = arith.extui %arg0 : i32 to i64
|
||||
// CHECK: %[[VAL2:.*]] = arith.extf %[[ARG3]] : f32 to f64
|
||||
%3 = arith.extf %arg3 : f32 to f64
|
||||
// CHECK: cond_br %[[COND]],
|
||||
// CHECK: cf.cond_br %[[COND]],
|
||||
// CHECK: ^[[AFTER:.*]](%[[VAL1]], %[[VAL2]] : i64, f64),
|
||||
// CHECK: ^[[CONT:.*]]
|
||||
scf.condition(%0) %2, %3 : i64, f64
|
||||
} do {
|
||||
// CHECK: ^[[AFTER]](%[[ARG4:.*]]: i64, %[[ARG5:.*]]: f64):
|
||||
^bb0(%arg2: i64, %arg3: f64):
|
||||
// CHECK: br ^[[BEFORE]](%{{.*}}, %{{.*}} : i32, f32)
|
||||
// CHECK: cf.br ^[[BEFORE]](%{{.*}}, %{{.*}} : i32, f32)
|
||||
scf.yield %c0_i32, %cst : i32, f32
|
||||
}
|
||||
// CHECK: ^bb3:
|
||||
|
@ -484,17 +484,17 @@ func @while_values(%arg0: i32, %arg1: f32) {
|
|||
|
||||
// CHECK-LABEL: @nested_while_ops
|
||||
func @nested_while_ops(%arg0: f32) -> i64 {
|
||||
// CHECK: br ^[[OUTER_BEFORE:.*]](%{{.*}} : f32)
|
||||
// CHECK: cf.br ^[[OUTER_BEFORE:.*]](%{{.*}} : f32)
|
||||
%0 = scf.while(%outer = %arg0) : (f32) -> i64 {
|
||||
// CHECK: ^[[OUTER_BEFORE]](%{{.*}}: f32):
|
||||
// CHECK: %[[OUTER_COND:.*]] = "test.outer_before_pre"() : () -> i1
|
||||
%cond = "test.outer_before_pre"() : () -> i1
|
||||
// CHECK: br ^[[INNER_BEFORE_BEFORE:.*]](%{{.*}} : f32)
|
||||
// CHECK: cf.br ^[[INNER_BEFORE_BEFORE:.*]](%{{.*}} : f32)
|
||||
%1 = scf.while(%inner = %outer) : (f32) -> i64 {
|
||||
// CHECK: ^[[INNER_BEFORE_BEFORE]](%{{.*}}: f32):
|
||||
// CHECK: %[[INNER1:.*]]:2 = "test.inner_before"(%{{.*}}) : (f32) -> (i1, i64)
|
||||
%2:2 = "test.inner_before"(%inner) : (f32) -> (i1, i64)
|
||||
// CHECK: cond_br %[[INNER1]]#0,
|
||||
// CHECK: cf.cond_br %[[INNER1]]#0,
|
||||
// CHECK: ^[[INNER_BEFORE_AFTER:.*]](%[[INNER1]]#1 : i64),
|
||||
// CHECK: ^[[OUTER_BEFORE_LAST:.*]]
|
||||
scf.condition(%2#0) %2#1 : i64
|
||||
|
@ -503,13 +503,13 @@ func @nested_while_ops(%arg0: f32) -> i64 {
|
|||
^bb0(%arg1: i64):
|
||||
// CHECK: %[[INNER2:.*]] = "test.inner_after"(%{{.*}}) : (i64) -> f32
|
||||
%3 = "test.inner_after"(%arg1) : (i64) -> f32
|
||||
// CHECK: br ^[[INNER_BEFORE_BEFORE]](%[[INNER2]] : f32)
|
||||
// CHECK: cf.br ^[[INNER_BEFORE_BEFORE]](%[[INNER2]] : f32)
|
||||
scf.yield %3 : f32
|
||||
}
|
||||
// CHECK: ^[[OUTER_BEFORE_LAST]]:
|
||||
// CHECK: "test.outer_before_post"() : () -> ()
|
||||
"test.outer_before_post"() : () -> ()
|
||||
// CHECK: cond_br %[[OUTER_COND]],
|
||||
// CHECK: cf.cond_br %[[OUTER_COND]],
|
||||
// CHECK: ^[[OUTER_AFTER:.*]](%[[INNER1]]#1 : i64),
|
||||
// CHECK: ^[[CONTINUATION:.*]]
|
||||
scf.condition(%cond) %1 : i64
|
||||
|
@ -518,12 +518,12 @@ func @nested_while_ops(%arg0: f32) -> i64 {
|
|||
^bb2(%arg2: i64):
|
||||
// CHECK: "test.outer_after_pre"(%{{.*}}) : (i64) -> ()
|
||||
"test.outer_after_pre"(%arg2) : (i64) -> ()
|
||||
// CHECK: br ^[[INNER_AFTER_BEFORE:.*]](%{{.*}} : i64)
|
||||
// CHECK: cf.br ^[[INNER_AFTER_BEFORE:.*]](%{{.*}} : i64)
|
||||
%4 = scf.while(%inner = %arg2) : (i64) -> f32 {
|
||||
// CHECK: ^[[INNER_AFTER_BEFORE]](%{{.*}}: i64):
|
||||
// CHECK: %[[INNER3:.*]]:2 = "test.inner2_before"(%{{.*}}) : (i64) -> (i1, f32)
|
||||
%5:2 = "test.inner2_before"(%inner) : (i64) -> (i1, f32)
|
||||
// CHECK: cond_br %[[INNER3]]#0,
|
||||
// CHECK: cf.cond_br %[[INNER3]]#0,
|
||||
// CHECK: ^[[INNER_AFTER_AFTER:.*]](%[[INNER3]]#1 : f32),
|
||||
// CHECK: ^[[OUTER_AFTER_LAST:.*]]
|
||||
scf.condition(%5#0) %5#1 : f32
|
||||
|
@ -532,13 +532,13 @@ func @nested_while_ops(%arg0: f32) -> i64 {
|
|||
^bb3(%arg3: f32):
|
||||
// CHECK: %{{.*}} = "test.inner2_after"(%{{.*}}) : (f32) -> i64
|
||||
%6 = "test.inner2_after"(%arg3) : (f32) -> i64
|
||||
// CHECK: br ^[[INNER_AFTER_BEFORE]](%{{.*}} : i64)
|
||||
// CHECK: cf.br ^[[INNER_AFTER_BEFORE]](%{{.*}} : i64)
|
||||
scf.yield %6 : i64
|
||||
}
|
||||
// CHECK: ^[[OUTER_AFTER_LAST]]:
|
||||
// CHECK: "test.outer_after_post"() : () -> ()
|
||||
"test.outer_after_post"() : () -> ()
|
||||
// CHECK: br ^[[OUTER_BEFORE]](%[[INNER3]]#1 : f32)
|
||||
// CHECK: cf.br ^[[OUTER_BEFORE]](%[[INNER3]]#1 : f32)
|
||||
scf.yield %4 : f32
|
||||
}
|
||||
// CHECK: ^[[CONTINUATION]]:
|
||||
|
@ -549,27 +549,27 @@ func @nested_while_ops(%arg0: f32) -> i64 {
|
|||
// CHECK-LABEL: @ifs_in_parallel
|
||||
// CHECK: (%[[ARG0:.*]]: index, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index, %[[ARG3:.*]]: i1, %[[ARG4:.*]]: i1)
|
||||
func @ifs_in_parallel(%arg1: index, %arg2: index, %arg3: index, %arg4: i1, %arg5: i1) {
|
||||
// CHECK: br ^[[LOOP_LATCH:.*]](%[[ARG0]] : index)
|
||||
// CHECK: cf.br ^[[LOOP_LATCH:.*]](%[[ARG0]] : index)
|
||||
// CHECK: ^[[LOOP_LATCH]](%[[LOOP_IV:.*]]: index):
|
||||
// CHECK: %[[LOOP_COND:.*]] = arith.cmpi slt, %[[LOOP_IV]], %[[ARG1]] : index
|
||||
// CHECK: cond_br %[[LOOP_COND]], ^[[LOOP_BODY:.*]], ^[[LOOP_CONT:.*]]
|
||||
// CHECK: cf.cond_br %[[LOOP_COND]], ^[[LOOP_BODY:.*]], ^[[LOOP_CONT:.*]]
|
||||
// CHECK: ^[[LOOP_BODY]]:
|
||||
// CHECK: cond_br %[[ARG3]], ^[[IF1_THEN:.*]], ^[[IF1_CONT:.*]]
|
||||
// CHECK: cf.cond_br %[[ARG3]], ^[[IF1_THEN:.*]], ^[[IF1_CONT:.*]]
|
||||
// CHECK: ^[[IF1_THEN]]:
|
||||
// CHECK: cond_br %[[ARG4]], ^[[IF2_THEN:.*]], ^[[IF2_ELSE:.*]]
|
||||
// CHECK: cf.cond_br %[[ARG4]], ^[[IF2_THEN:.*]], ^[[IF2_ELSE:.*]]
|
||||
// CHECK: ^[[IF2_THEN]]:
|
||||
// CHECK: %{{.*}} = "test.if2"() : () -> index
|
||||
// CHECK: br ^[[IF2_MERGE:.*]](%{{.*}} : index)
|
||||
// CHECK: cf.br ^[[IF2_MERGE:.*]](%{{.*}} : index)
|
||||
// CHECK: ^[[IF2_ELSE]]:
|
||||
// CHECK: %{{.*}} = "test.else2"() : () -> index
|
||||
// CHECK: br ^[[IF2_MERGE]](%{{.*}} : index)
|
||||
// CHECK: cf.br ^[[IF2_MERGE]](%{{.*}} : index)
|
||||
// CHECK: ^[[IF2_MERGE]](%{{.*}}: index):
|
||||
// CHECK: br ^[[IF2_CONT:.*]]
|
||||
// CHECK: cf.br ^[[IF2_CONT:.*]]
|
||||
// CHECK: ^[[IF2_CONT]]:
|
||||
// CHECK: br ^[[IF1_CONT]]
|
||||
// CHECK: cf.br ^[[IF1_CONT]]
|
||||
// CHECK: ^[[IF1_CONT]]:
|
||||
// CHECK: %{{.*}} = arith.addi %[[LOOP_IV]], %[[ARG2]] : index
|
||||
// CHECK: br ^[[LOOP_LATCH]](%{{.*}} : index)
|
||||
// CHECK: cf.br ^[[LOOP_LATCH]](%{{.*}} : index)
|
||||
scf.parallel (%i) = (%arg1) to (%arg2) step (%arg3) {
|
||||
scf.if %arg4 {
|
||||
%0 = scf.if %arg5 -> (index) {
|
||||
|
@ -593,7 +593,7 @@ func @func_execute_region_elim_multi_yield() {
|
|||
"test.foo"() : () -> ()
|
||||
%v = scf.execute_region -> i64 {
|
||||
%c = "test.cmp"() : () -> i1
|
||||
cond_br %c, ^bb2, ^bb3
|
||||
cf.cond_br %c, ^bb2, ^bb3
|
||||
^bb2:
|
||||
%x = "test.val1"() : () -> i64
|
||||
scf.yield %x : i64
|
||||
|
@ -607,16 +607,16 @@ func @func_execute_region_elim_multi_yield() {
|
|||
|
||||
// CHECK-NOT: execute_region
|
||||
// CHECK: "test.foo"
|
||||
// CHECK: br ^[[rentry:.+]]
|
||||
// CHECK: cf.br ^[[rentry:.+]]
|
||||
// CHECK: ^[[rentry]]
|
||||
// CHECK: %[[cmp:.+]] = "test.cmp"
|
||||
// CHECK: cond_br %[[cmp]], ^[[bb1:.+]], ^[[bb2:.+]]
|
||||
// CHECK: cf.cond_br %[[cmp]], ^[[bb1:.+]], ^[[bb2:.+]]
|
||||
// CHECK: ^[[bb1]]:
|
||||
// CHECK: %[[x:.+]] = "test.val1"
|
||||
// CHECK: br ^[[bb3:.+]](%[[x]] : i64)
|
||||
// CHECK: cf.br ^[[bb3:.+]](%[[x]] : i64)
|
||||
// CHECK: ^[[bb2]]:
|
||||
// CHECK: %[[y:.+]] = "test.val2"
|
||||
// CHECK: br ^[[bb3]](%[[y:.+]] : i64)
|
||||
// CHECK: cf.br ^[[bb3]](%[[y]] : i64)
|
||||
// CHECK: ^[[bb3]](%[[z:.+]]: i64):
|
||||
// CHECK: "test.bar"(%[[z]])
|
||||
// CHECK: return
|
|
@ -6,7 +6,7 @@
|
|||
// CHECK-SAME: %[[RHS:.*]]: tensor<?xindex>) -> !shape.witness {
|
||||
// CHECK: %[[RET:.*]] = shape.const_witness true
|
||||
// CHECK: %[[BROADCAST_IS_VALID:.*]] = shape.is_broadcastable %[[LHS]], %[[RHS]]
|
||||
// CHECK: assert %[[BROADCAST_IS_VALID]], "required broadcastable shapes"
|
||||
// CHECK: cf.assert %[[BROADCAST_IS_VALID]], "required broadcastable shapes"
|
||||
// CHECK: return %[[RET]] : !shape.witness
|
||||
// CHECK: }
|
||||
func @cstr_broadcastable(%arg0: tensor<?xindex>, %arg1: tensor<?xindex>) -> !shape.witness {
|
||||
|
@ -19,7 +19,7 @@ func @cstr_broadcastable(%arg0: tensor<?xindex>, %arg1: tensor<?xindex>) -> !sha
|
|||
// CHECK-SAME: %[[RHS:.*]]: tensor<?xindex>) -> !shape.witness {
|
||||
// CHECK: %[[RET:.*]] = shape.const_witness true
|
||||
// CHECK: %[[EQUAL_IS_VALID:.*]] = shape.shape_eq %[[LHS]], %[[RHS]]
|
||||
// CHECK: assert %[[EQUAL_IS_VALID]], "required equal shapes"
|
||||
// CHECK: cf.assert %[[EQUAL_IS_VALID]], "required equal shapes"
|
||||
// CHECK: return %[[RET]] : !shape.witness
|
||||
// CHECK: }
|
||||
func @cstr_eq(%arg0: tensor<?xindex>, %arg1: tensor<?xindex>) -> !shape.witness {
|
||||
|
@ -30,7 +30,7 @@ func @cstr_eq(%arg0: tensor<?xindex>, %arg1: tensor<?xindex>) -> !shape.witness
|
|||
// CHECK-LABEL: func @cstr_require
|
||||
func @cstr_require(%arg0: i1) -> !shape.witness {
|
||||
// CHECK: %[[RET:.*]] = shape.const_witness true
|
||||
// CHECK: assert %arg0, "msg"
|
||||
// CHECK: cf.assert %arg0, "msg"
|
||||
// CHECK: return %[[RET]]
|
||||
%witness = shape.cstr_require %arg0, "msg"
|
||||
return %witness : !shape.witness
|
||||
|
|
|
@ -29,7 +29,7 @@ func private @memref_call_conv_nested(%arg0: (memref<?xf32>) -> ())
|
|||
//CHECK-LABEL: llvm.func @pass_through(%arg0: !llvm.ptr<func<void ()>>) -> !llvm.ptr<func<void ()>> {
|
||||
func @pass_through(%arg0: () -> ()) -> (() -> ()) {
|
||||
// CHECK-NEXT: llvm.br ^bb1(%arg0 : !llvm.ptr<func<void ()>>)
|
||||
br ^bb1(%arg0 : () -> ())
|
||||
cf.br ^bb1(%arg0 : () -> ())
|
||||
|
||||
//CHECK-NEXT: ^bb1(%0: !llvm.ptr<func<void ()>>):
|
||||
^bb1(%bbarg: () -> ()):
|
||||
|
|
|
@ -109,17 +109,17 @@ func @loop_carried(%arg0 : index, %arg1 : index, %arg2 : index, %base0 : !base_t
|
|||
// This test checks that in the BAREPTR case, the branch arguments only forward the descriptor.
|
||||
// This test was lowered from a simple scf.for that swaps 2 memref iter_args.
|
||||
// BAREPTR: llvm.br ^bb1(%{{.*}}, %{{.*}}, %{{.*}} : i64, !llvm.struct<(ptr<i32, 201>, ptr<i32, 201>, i64, array<1 x i64>, array<1 x i64>)>, !llvm.struct<(ptr<i32, 201>, ptr<i32, 201>, i64, array<1 x i64>, array<1 x i64>)>)
|
||||
br ^bb1(%arg0, %base0, %base1 : index, memref<64xi32, 201>, memref<64xi32, 201>)
|
||||
cf.br ^bb1(%arg0, %base0, %base1 : index, memref<64xi32, 201>, memref<64xi32, 201>)
|
||||
|
||||
// BAREPTR-NEXT: ^bb1
|
||||
// BAREPTR-NEXT: llvm.icmp
|
||||
// BAREPTR-NEXT: llvm.cond_br %{{.*}}, ^bb2, ^bb3
|
||||
^bb1(%0: index, %1: memref<64xi32, 201>, %2: memref<64xi32, 201>): // 2 preds: ^bb0, ^bb2
|
||||
%3 = arith.cmpi slt, %0, %arg1 : index
|
||||
cond_br %3, ^bb2, ^bb3
|
||||
cf.cond_br %3, ^bb2, ^bb3
|
||||
^bb2: // pred: ^bb1
|
||||
%4 = arith.addi %0, %arg2 : index
|
||||
br ^bb1(%4, %2, %1 : index, memref<64xi32, 201>, memref<64xi32, 201>)
|
||||
cf.br ^bb1(%4, %2, %1 : index, memref<64xi32, 201>, memref<64xi32, 201>)
|
||||
^bb3: // pred: ^bb1
|
||||
return %1, %2 : memref<64xi32, 201>, memref<64xi32, 201>
|
||||
}
|
||||
|
|
|
@ -18,7 +18,7 @@ func @simple_loop() {
|
|||
^bb0:
|
||||
// CHECK-NEXT: llvm.br ^bb1
|
||||
// CHECK32-NEXT: llvm.br ^bb1
|
||||
br ^bb1
|
||||
cf.br ^bb1
|
||||
|
||||
// CHECK-NEXT: ^bb1: // pred: ^bb0
|
||||
// CHECK-NEXT: {{.*}} = llvm.mlir.constant(1 : index) : i64
|
||||
|
@ -31,7 +31,7 @@ func @simple_loop() {
|
|||
^bb1: // pred: ^bb0
|
||||
%c1 = arith.constant 1 : index
|
||||
%c42 = arith.constant 42 : index
|
||||
br ^bb2(%c1 : index)
|
||||
cf.br ^bb2(%c1 : index)
|
||||
|
||||
// CHECK: ^bb2({{.*}}: i64): // 2 preds: ^bb1, ^bb3
|
||||
// CHECK-NEXT: {{.*}} = llvm.icmp "slt" {{.*}}, {{.*}} : i64
|
||||
|
@ -41,7 +41,7 @@ func @simple_loop() {
|
|||
// CHECK32-NEXT: llvm.cond_br {{.*}}, ^bb3, ^bb4
|
||||
^bb2(%0: index): // 2 preds: ^bb1, ^bb3
|
||||
%1 = arith.cmpi slt, %0, %c42 : index
|
||||
cond_br %1, ^bb3, ^bb4
|
||||
cf.cond_br %1, ^bb3, ^bb4
|
||||
|
||||
// CHECK: ^bb3: // pred: ^bb2
|
||||
// CHECK-NEXT: llvm.call @body({{.*}}) : (i64) -> ()
|
||||
|
@ -57,7 +57,7 @@ func @simple_loop() {
|
|||
call @body(%0) : (index) -> ()
|
||||
%c1_0 = arith.constant 1 : index
|
||||
%2 = arith.addi %0, %c1_0 : index
|
||||
br ^bb2(%2 : index)
|
||||
cf.br ^bb2(%2 : index)
|
||||
|
||||
// CHECK: ^bb4: // pred: ^bb2
|
||||
// CHECK-NEXT: llvm.return
|
||||
|
@ -111,7 +111,7 @@ func private @other(index, i32) -> i32
|
|||
func @func_args(i32, i32) -> i32 {
|
||||
^bb0(%arg0: i32, %arg1: i32):
|
||||
%c0_i32 = arith.constant 0 : i32
|
||||
br ^bb1
|
||||
cf.br ^bb1
|
||||
|
||||
// CHECK-NEXT: ^bb1: // pred: ^bb0
|
||||
// CHECK-NEXT: {{.*}} = llvm.mlir.constant(0 : index) : i64
|
||||
|
@ -124,7 +124,7 @@ func @func_args(i32, i32) -> i32 {
|
|||
^bb1: // pred: ^bb0
|
||||
%c0 = arith.constant 0 : index
|
||||
%c42 = arith.constant 42 : index
|
||||
br ^bb2(%c0 : index)
|
||||
cf.br ^bb2(%c0 : index)
|
||||
|
||||
// CHECK-NEXT: ^bb2({{.*}}: i64): // 2 preds: ^bb1, ^bb3
|
||||
// CHECK-NEXT: {{.*}} = llvm.icmp "slt" {{.*}}, {{.*}} : i64
|
||||
|
@ -134,7 +134,7 @@ func @func_args(i32, i32) -> i32 {
|
|||
// CHECK32-NEXT: llvm.cond_br {{.*}}, ^bb3, ^bb4
|
||||
^bb2(%0: index): // 2 preds: ^bb1, ^bb3
|
||||
%1 = arith.cmpi slt, %0, %c42 : index
|
||||
cond_br %1, ^bb3, ^bb4
|
||||
cf.cond_br %1, ^bb3, ^bb4
|
||||
|
||||
// CHECK-NEXT: ^bb3: // pred: ^bb2
|
||||
// CHECK-NEXT: {{.*}} = llvm.call @body_args({{.*}}) : (i64) -> i64
|
||||
|
@ -159,7 +159,7 @@ func @func_args(i32, i32) -> i32 {
|
|||
%5 = call @other(%2, %arg1) : (index, i32) -> i32
|
||||
%c1 = arith.constant 1 : index
|
||||
%6 = arith.addi %0, %c1 : index
|
||||
br ^bb2(%6 : index)
|
||||
cf.br ^bb2(%6 : index)
|
||||
|
||||
// CHECK-NEXT: ^bb4: // pred: ^bb2
|
||||
// CHECK-NEXT: {{.*}} = llvm.mlir.constant(0 : index) : i64
|
||||
|
@ -191,7 +191,7 @@ func private @post(index)
|
|||
// CHECK-NEXT: llvm.br ^bb1
|
||||
func @imperfectly_nested_loops() {
|
||||
^bb0:
|
||||
br ^bb1
|
||||
cf.br ^bb1
|
||||
|
||||
// CHECK-NEXT: ^bb1: // pred: ^bb0
|
||||
// CHECK-NEXT: {{.*}} = llvm.mlir.constant(0 : index) : i64
|
||||
|
@ -200,21 +200,21 @@ func @imperfectly_nested_loops() {
|
|||
^bb1: // pred: ^bb0
|
||||
%c0 = arith.constant 0 : index
|
||||
%c42 = arith.constant 42 : index
|
||||
br ^bb2(%c0 : index)
|
||||
cf.br ^bb2(%c0 : index)
|
||||
|
||||
// CHECK-NEXT: ^bb2({{.*}}: i64): // 2 preds: ^bb1, ^bb7
|
||||
// CHECK-NEXT: {{.*}} = llvm.icmp "slt" {{.*}}, {{.*}} : i64
|
||||
// CHECK-NEXT: llvm.cond_br {{.*}}, ^bb3, ^bb8
|
||||
^bb2(%0: index): // 2 preds: ^bb1, ^bb7
|
||||
%1 = arith.cmpi slt, %0, %c42 : index
|
||||
cond_br %1, ^bb3, ^bb8
|
||||
cf.cond_br %1, ^bb3, ^bb8
|
||||
|
||||
// CHECK-NEXT: ^bb3:
|
||||
// CHECK-NEXT: llvm.call @pre({{.*}}) : (i64) -> ()
|
||||
// CHECK-NEXT: llvm.br ^bb4
|
||||
^bb3: // pred: ^bb2
|
||||
call @pre(%0) : (index) -> ()
|
||||
br ^bb4
|
||||
cf.br ^bb4
|
||||
|
||||
// CHECK-NEXT: ^bb4: // pred: ^bb3
|
||||
// CHECK-NEXT: {{.*}} = llvm.mlir.constant(7 : index) : i64
|
||||
|
@ -223,14 +223,14 @@ func @imperfectly_nested_loops() {
|
|||
^bb4: // pred: ^bb3
|
||||
%c7 = arith.constant 7 : index
|
||||
%c56 = arith.constant 56 : index
|
||||
br ^bb5(%c7 : index)
|
||||
cf.br ^bb5(%c7 : index)
|
||||
|
||||
// CHECK-NEXT: ^bb5({{.*}}: i64): // 2 preds: ^bb4, ^bb6
|
||||
// CHECK-NEXT: {{.*}} = llvm.icmp "slt" {{.*}}, {{.*}} : i64
|
||||
// CHECK-NEXT: llvm.cond_br {{.*}}, ^bb6, ^bb7
|
||||
^bb5(%2: index): // 2 preds: ^bb4, ^bb6
|
||||
%3 = arith.cmpi slt, %2, %c56 : index
|
||||
cond_br %3, ^bb6, ^bb7
|
||||
cf.cond_br %3, ^bb6, ^bb7
|
||||
|
||||
// CHECK-NEXT: ^bb6: // pred: ^bb5
|
||||
// CHECK-NEXT: llvm.call @body2({{.*}}, {{.*}}) : (i64, i64) -> ()
|
||||
|
@ -241,7 +241,7 @@ func @imperfectly_nested_loops() {
|
|||
call @body2(%0, %2) : (index, index) -> ()
|
||||
%c2 = arith.constant 2 : index
|
||||
%4 = arith.addi %2, %c2 : index
|
||||
br ^bb5(%4 : index)
|
||||
cf.br ^bb5(%4 : index)
|
||||
|
||||
// CHECK-NEXT: ^bb7: // pred: ^bb5
|
||||
// CHECK-NEXT: llvm.call @post({{.*}}) : (i64) -> ()
|
||||
|
@ -252,7 +252,7 @@ func @imperfectly_nested_loops() {
|
|||
call @post(%0) : (index) -> ()
|
||||
%c1 = arith.constant 1 : index
|
||||
%5 = arith.addi %0, %c1 : index
|
||||
br ^bb2(%5 : index)
|
||||
cf.br ^bb2(%5 : index)
|
||||
|
||||
// CHECK-NEXT: ^bb8: // pred: ^bb2
|
||||
// CHECK-NEXT: llvm.return
|
||||
|
@ -316,49 +316,49 @@ func private @body3(index, index)
|
|||
// CHECK-NEXT: }
|
||||
func @more_imperfectly_nested_loops() {
|
||||
^bb0:
|
||||
br ^bb1
|
||||
cf.br ^bb1
|
||||
^bb1: // pred: ^bb0
|
||||
%c0 = arith.constant 0 : index
|
||||
%c42 = arith.constant 42 : index
|
||||
br ^bb2(%c0 : index)
|
||||
cf.br ^bb2(%c0 : index)
|
||||
^bb2(%0: index): // 2 preds: ^bb1, ^bb11
|
||||
%1 = arith.cmpi slt, %0, %c42 : index
|
||||
cond_br %1, ^bb3, ^bb12
|
||||
cf.cond_br %1, ^bb3, ^bb12
|
||||
^bb3: // pred: ^bb2
|
||||
call @pre(%0) : (index) -> ()
|
||||
br ^bb4
|
||||
cf.br ^bb4
|
||||
^bb4: // pred: ^bb3
|
||||
%c7 = arith.constant 7 : index
|
||||
%c56 = arith.constant 56 : index
|
||||
br ^bb5(%c7 : index)
|
||||
cf.br ^bb5(%c7 : index)
|
||||
^bb5(%2: index): // 2 preds: ^bb4, ^bb6
|
||||
%3 = arith.cmpi slt, %2, %c56 : index
|
||||
cond_br %3, ^bb6, ^bb7
|
||||
cf.cond_br %3, ^bb6, ^bb7
|
||||
^bb6: // pred: ^bb5
|
||||
call @body2(%0, %2) : (index, index) -> ()
|
||||
%c2 = arith.constant 2 : index
|
||||
%4 = arith.addi %2, %c2 : index
|
||||
br ^bb5(%4 : index)
|
||||
cf.br ^bb5(%4 : index)
|
||||
^bb7: // pred: ^bb5
|
||||
call @mid(%0) : (index) -> ()
|
||||
br ^bb8
|
||||
cf.br ^bb8
|
||||
^bb8: // pred: ^bb7
|
||||
%c18 = arith.constant 18 : index
|
||||
%c37 = arith.constant 37 : index
|
||||
br ^bb9(%c18 : index)
|
||||
cf.br ^bb9(%c18 : index)
|
||||
^bb9(%5: index): // 2 preds: ^bb8, ^bb10
|
||||
%6 = arith.cmpi slt, %5, %c37 : index
|
||||
cond_br %6, ^bb10, ^bb11
|
||||
cf.cond_br %6, ^bb10, ^bb11
|
||||
^bb10: // pred: ^bb9
|
||||
call @body3(%0, %5) : (index, index) -> ()
|
||||
%c3 = arith.constant 3 : index
|
||||
%7 = arith.addi %5, %c3 : index
|
||||
br ^bb9(%7 : index)
|
||||
cf.br ^bb9(%7 : index)
|
||||
^bb11: // pred: ^bb9
|
||||
call @post(%0) : (index) -> ()
|
||||
%c1 = arith.constant 1 : index
|
||||
%8 = arith.addi %0, %c1 : index
|
||||
br ^bb2(%8 : index)
|
||||
cf.br ^bb2(%8 : index)
|
||||
^bb12: // pred: ^bb2
|
||||
return
|
||||
}
|
||||
|
@ -432,7 +432,7 @@ func @dfs_block_order(%arg0: i32) -> (i32) {
|
|||
// CHECK-NEXT: %[[CST:.*]] = llvm.mlir.constant(42 : i32) : i32
|
||||
%0 = arith.constant 42 : i32
|
||||
// CHECK-NEXT: llvm.br ^bb2
|
||||
br ^bb2
|
||||
cf.br ^bb2
|
||||
|
||||
// CHECK-NEXT: ^bb1:
|
||||
// CHECK-NEXT: %[[ADD:.*]] = llvm.add %arg0, %[[CST]] : i32
|
||||
|
@ -444,7 +444,7 @@ func @dfs_block_order(%arg0: i32) -> (i32) {
|
|||
// CHECK-NEXT: ^bb2:
|
||||
^bb2:
|
||||
// CHECK-NEXT: llvm.br ^bb1
|
||||
br ^bb1
|
||||
cf.br ^bb1
|
||||
}
|
||||
|
||||
// -----
|
||||
|
@ -469,7 +469,7 @@ func @floorf(%arg0 : f32) {
|
|||
|
||||
// -----
|
||||
|
||||
// Lowers `assert` to a function call to `abort` if the assertion is violated.
|
||||
// Lowers `cf.assert` to a function call to `abort` if the assertion is violated.
|
||||
// CHECK: llvm.func @abort()
|
||||
// CHECK-LABEL: @assert_test_function
|
||||
// CHECK-SAME: (%[[ARG:.*]]: i1)
|
||||
|
@ -480,7 +480,7 @@ func @assert_test_function(%arg : i1) {
|
|||
// CHECK: ^[[FAILURE_BLOCK]]:
|
||||
// CHECK: llvm.call @abort() : () -> ()
|
||||
// CHECK: llvm.unreachable
|
||||
assert %arg, "Computer says no"
|
||||
cf.assert %arg, "Computer says no"
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -514,8 +514,8 @@ func @fmaf(%arg0: f32, %arg1: vector<4xf32>) {
|
|||
|
||||
// CHECK-LABEL: func @switchi8(
|
||||
func @switchi8(%arg0 : i8) -> i32 {
|
||||
switch %arg0 : i8, [
|
||||
default: ^bb1,
|
||||
cf.switch %arg0 : i8, [
|
||||
default: ^bb1,
|
||||
42: ^bb1,
|
||||
43: ^bb3
|
||||
]
|
||||
|
|
|
@ -900,45 +900,3 @@ func @tensor_extract_constant(%a : index, %b: index, %c: index) -> i32 {
|
|||
// CHECK: spv.ReturnValue %[[VAL]]
|
||||
return %extract : i32
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// std.br, std.cond_br
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
module attributes {
|
||||
spv.target_env = #spv.target_env<#spv.vce<v1.0, [], []>, {}>
|
||||
} {
|
||||
|
||||
// CHECK-LABEL: func @simple_loop
|
||||
func @simple_loop(index, index, index) {
|
||||
^bb0(%begin : index, %end : index, %step : index):
|
||||
// CHECK-NEXT: spv.Branch ^bb1
|
||||
br ^bb1
|
||||
|
||||
// CHECK-NEXT: ^bb1: // pred: ^bb0
|
||||
// CHECK-NEXT: spv.Branch ^bb2({{.*}} : i32)
|
||||
^bb1: // pred: ^bb0
|
||||
br ^bb2(%begin : index)
|
||||
|
||||
// CHECK: ^bb2({{.*}}: i32): // 2 preds: ^bb1, ^bb3
|
||||
// CHECK-NEXT: {{.*}} = spv.SLessThan {{.*}}, {{.*}} : i32
|
||||
// CHECK-NEXT: spv.BranchConditional {{.*}}, ^bb3, ^bb4
|
||||
^bb2(%0: index): // 2 preds: ^bb1, ^bb3
|
||||
%1 = arith.cmpi slt, %0, %end : index
|
||||
cond_br %1, ^bb3, ^bb4
|
||||
|
||||
// CHECK: ^bb3: // pred: ^bb2
|
||||
// CHECK-NEXT: {{.*}} = spv.IAdd {{.*}}, {{.*}} : i32
|
||||
// CHECK-NEXT: spv.Branch ^bb2({{.*}} : i32)
|
||||
^bb3: // pred: ^bb2
|
||||
%2 = arith.addi %0, %step : index
|
||||
br ^bb2(%2 : index)
|
||||
|
||||
// CHECK: ^bb4: // pred: ^bb2
|
||||
^bb4: // pred: ^bb2
|
||||
return
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -56,9 +56,9 @@ func @affine_load_invalid_dim(%M : memref<10xi32>) {
|
|||
^bb0(%arg: index):
|
||||
affine.load %M[%arg] : memref<10xi32>
|
||||
// expected-error@-1 {{index must be a dimension or symbol identifier}}
|
||||
br ^bb1
|
||||
cf.br ^bb1
|
||||
^bb1:
|
||||
br ^bb1
|
||||
cf.br ^bb1
|
||||
}) : () -> ()
|
||||
return
|
||||
}
|
||||
|
|
|
@ -54,13 +54,13 @@ func @token_value_to_func() {
|
|||
// CHECK-LABEL: @token_arg_cond_br_await_with_fallthough
|
||||
// CHECK: %[[TOKEN:.*]]: !async.token
|
||||
func @token_arg_cond_br_await_with_fallthough(%arg0: !async.token, %arg1: i1) {
|
||||
// CHECK: cond_br
|
||||
// CHECK: cf.cond_br
|
||||
// CHECK-SAME: ^[[BB1:.*]], ^[[BB2:.*]]
|
||||
cond_br %arg1, ^bb1, ^bb2
|
||||
cf.cond_br %arg1, ^bb1, ^bb2
|
||||
^bb1:
|
||||
// CHECK: ^[[BB1]]:
|
||||
// CHECK: br ^[[BB2]]
|
||||
br ^bb2
|
||||
// CHECK: cf.br ^[[BB2]]
|
||||
cf.br ^bb2
|
||||
^bb2:
|
||||
// CHECK: ^[[BB2]]:
|
||||
// CHECK: async.runtime.await %[[TOKEN]]
|
||||
|
@ -88,10 +88,10 @@ func @token_coro_return() -> !async.token {
|
|||
async.runtime.resume %hdl
|
||||
async.coro.suspend %saved, ^suspend, ^resume, ^cleanup
|
||||
^resume:
|
||||
br ^cleanup
|
||||
cf.br ^cleanup
|
||||
^cleanup:
|
||||
async.coro.free %id, %hdl
|
||||
br ^suspend
|
||||
cf.br ^suspend
|
||||
^suspend:
|
||||
async.coro.end %hdl
|
||||
return %token : !async.token
|
||||
|
@ -109,10 +109,10 @@ func @token_coro_await_and_resume(%arg0: !async.token) -> !async.token {
|
|||
// CHECK-NEXT: async.runtime.drop_ref %[[TOKEN]] {count = 1 : i64}
|
||||
async.coro.suspend %saved, ^suspend, ^resume, ^cleanup
|
||||
^resume:
|
||||
br ^cleanup
|
||||
cf.br ^cleanup
|
||||
^cleanup:
|
||||
async.coro.free %id, %hdl
|
||||
br ^suspend
|
||||
cf.br ^suspend
|
||||
^suspend:
|
||||
async.coro.end %hdl
|
||||
return %token : !async.token
|
||||
|
@ -137,10 +137,10 @@ func @value_coro_await_and_resume(%arg0: !async.value<f32>) -> !async.token {
|
|||
%0 = async.runtime.load %arg0 : !async.value<f32>
|
||||
// CHECK: arith.addf %[[LOADED]], %[[LOADED]]
|
||||
%1 = arith.addf %0, %0 : f32
|
||||
br ^cleanup
|
||||
cf.br ^cleanup
|
||||
^cleanup:
|
||||
async.coro.free %id, %hdl
|
||||
br ^suspend
|
||||
cf.br ^suspend
|
||||
^suspend:
|
||||
async.coro.end %hdl
|
||||
return %token : !async.token
|
||||
|
@ -167,12 +167,12 @@ func private @outlined_async_execute(%arg0: !async.token) -> !async.token {
|
|||
// CHECK: ^[[RESUME_1:.*]]:
|
||||
// CHECK: async.runtime.set_available
|
||||
async.runtime.set_available %0 : !async.token
|
||||
br ^cleanup
|
||||
cf.br ^cleanup
|
||||
^cleanup:
|
||||
// CHECK: ^[[CLEANUP:.*]]:
|
||||
// CHECK: async.coro.free
|
||||
async.coro.free %1, %2
|
||||
br ^suspend
|
||||
cf.br ^suspend
|
||||
^suspend:
|
||||
// CHECK: ^[[SUSPEND:.*]]:
|
||||
// CHECK: async.coro.end
|
||||
|
@ -198,7 +198,7 @@ func @token_await_inside_nested_region(%arg0: i1) {
|
|||
|
||||
// CHECK-LABEL: @token_defined_in_the_loop
|
||||
func @token_defined_in_the_loop() {
|
||||
br ^bb1
|
||||
cf.br ^bb1
|
||||
^bb1:
|
||||
// CHECK: ^[[BB1:.*]]:
|
||||
// CHECK: %[[TOKEN:.*]] = call @token()
|
||||
|
@ -207,7 +207,7 @@ func @token_defined_in_the_loop() {
|
|||
// CHECK: async.runtime.drop_ref %[[TOKEN]] {count = 1 : i64}
|
||||
async.runtime.await %token : !async.token
|
||||
%0 = call @cond(): () -> (i1)
|
||||
cond_br %0, ^bb1, ^bb2
|
||||
cf.cond_br %0, ^bb1, ^bb2
|
||||
^bb2:
|
||||
// CHECK: ^[[BB2:.*]]:
|
||||
// CHECK: return
|
||||
|
@ -218,18 +218,18 @@ func @token_defined_in_the_loop() {
|
|||
func @divergent_liveness_one_token(%arg0 : i1) {
|
||||
// CHECK: %[[TOKEN:.*]] = call @token()
|
||||
%token = call @token() : () -> !async.token
|
||||
// CHECK: cond_br %arg0, ^[[LIVE_IN:.*]], ^[[REF_COUNTING:.*]]
|
||||
cond_br %arg0, ^bb1, ^bb2
|
||||
// CHECK: cf.cond_br %arg0, ^[[LIVE_IN:.*]], ^[[REF_COUNTING:.*]]
|
||||
cf.cond_br %arg0, ^bb1, ^bb2
|
||||
^bb1:
|
||||
// CHECK: ^[[LIVE_IN]]:
|
||||
// CHECK: async.runtime.await %[[TOKEN]]
|
||||
// CHECK: async.runtime.drop_ref %[[TOKEN]] {count = 1 : i64}
|
||||
// CHECK: br ^[[RETURN:.*]]
|
||||
// CHECK: cf.br ^[[RETURN:.*]]
|
||||
async.runtime.await %token : !async.token
|
||||
br ^bb2
|
||||
cf.br ^bb2
|
||||
// CHECK: ^[[REF_COUNTING:.*]]:
|
||||
// CHECK: async.runtime.drop_ref %[[TOKEN]] {count = 1 : i64}
|
||||
// CHECK: br ^[[RETURN:.*]]
|
||||
// CHECK: cf.br ^[[RETURN:.*]]
|
||||
^bb2:
|
||||
// CHECK: ^[[RETURN]]:
|
||||
// CHECK: return
|
||||
|
@ -240,20 +240,20 @@ func @divergent_liveness_one_token(%arg0 : i1) {
|
|||
func @divergent_liveness_unique_predecessor(%arg0 : i1) {
|
||||
// CHECK: %[[TOKEN:.*]] = call @token()
|
||||
%token = call @token() : () -> !async.token
|
||||
// CHECK: cond_br %arg0, ^[[LIVE_IN:.*]], ^[[NO_LIVE_IN:.*]]
|
||||
cond_br %arg0, ^bb2, ^bb1
|
||||
// CHECK: cf.cond_br %arg0, ^[[LIVE_IN:.*]], ^[[NO_LIVE_IN:.*]]
|
||||
cf.cond_br %arg0, ^bb2, ^bb1
|
||||
^bb1:
|
||||
// CHECK: ^[[NO_LIVE_IN]]:
|
||||
// CHECK: async.runtime.drop_ref %[[TOKEN]] {count = 1 : i64}
|
||||
// CHECK: br ^[[RETURN:.*]]
|
||||
br ^bb3
|
||||
// CHECK: cf.br ^[[RETURN:.*]]
|
||||
cf.br ^bb3
|
||||
^bb2:
|
||||
// CHECK: ^[[LIVE_IN]]:
|
||||
// CHECK: async.runtime.await %[[TOKEN]]
|
||||
// CHECK: async.runtime.drop_ref %[[TOKEN]] {count = 1 : i64}
|
||||
// CHECK: br ^[[RETURN]]
|
||||
// CHECK: cf.br ^[[RETURN]]
|
||||
async.runtime.await %token : !async.token
|
||||
br ^bb3
|
||||
cf.br ^bb3
|
||||
^bb3:
|
||||
// CHECK: ^[[RETURN]]:
|
||||
// CHECK: return
|
||||
|
@ -266,24 +266,24 @@ func @divergent_liveness_two_tokens(%arg0 : i1) {
|
|||
// CHECK: %[[TOKEN1:.*]] = call @token()
|
||||
%token0 = call @token() : () -> !async.token
|
||||
%token1 = call @token() : () -> !async.token
|
||||
// CHECK: cond_br %arg0, ^[[AWAIT0:.*]], ^[[AWAIT1:.*]]
|
||||
cond_br %arg0, ^await0, ^await1
|
||||
// CHECK: cf.cond_br %arg0, ^[[AWAIT0:.*]], ^[[AWAIT1:.*]]
|
||||
cf.cond_br %arg0, ^await0, ^await1
|
||||
^await0:
|
||||
// CHECK: ^[[AWAIT0]]:
|
||||
// CHECK: async.runtime.drop_ref %[[TOKEN1]] {count = 1 : i64}
|
||||
// CHECK: async.runtime.await %[[TOKEN0]]
|
||||
// CHECK: async.runtime.drop_ref %[[TOKEN0]] {count = 1 : i64}
|
||||
// CHECK: br ^[[RETURN:.*]]
|
||||
// CHECK: cf.br ^[[RETURN:.*]]
|
||||
async.runtime.await %token0 : !async.token
|
||||
br ^ret
|
||||
cf.br ^ret
|
||||
^await1:
|
||||
// CHECK: ^[[AWAIT1]]:
|
||||
// CHECK: async.runtime.drop_ref %[[TOKEN0]] {count = 1 : i64}
|
||||
// CHECK: async.runtime.await %[[TOKEN1]]
|
||||
// CHECK: async.runtime.drop_ref %[[TOKEN1]] {count = 1 : i64}
|
||||
// CHECK: br ^[[RETURN]]
|
||||
// CHECK: cf.br ^[[RETURN]]
|
||||
async.runtime.await %token1 : !async.token
|
||||
br ^ret
|
||||
cf.br ^ret
|
||||
^ret:
|
||||
// CHECK: ^[[RETURN]]:
|
||||
// CHECK: return
|
||||
|
|
|
@ -10,7 +10,7 @@ func @simple_callee(%arg0: f32) -> (f32 {builtin.foo = "bar"}) {
|
|||
// CHECK: %[[RETURNED_STORAGE:.*]] = async.runtime.create : !async.value<f32>
|
||||
// CHECK: %[[ID:.*]] = async.coro.id
|
||||
// CHECK: %[[HDL:.*]] = async.coro.begin %[[ID]]
|
||||
// CHECK: br ^[[ORIGINAL_ENTRY:.*]]
|
||||
// CHECK: cf.br ^[[ORIGINAL_ENTRY:.*]]
|
||||
// CHECK: ^[[ORIGINAL_ENTRY]]:
|
||||
// CHECK: %[[VAL:.*]] = arith.addf %[[ARG]], %[[ARG]] : f32
|
||||
%0 = arith.addf %arg0, %arg0 : f32
|
||||
|
@ -29,7 +29,7 @@ func @simple_callee(%arg0: f32) -> (f32 {builtin.foo = "bar"}) {
|
|||
|
||||
// CHECK: ^[[RESUME]]:
|
||||
// CHECK: %[[IS_ERROR:.*]] = async.runtime.is_error %[[VAL_STORAGE]] : !async.value<f32>
|
||||
// CHECK: cond_br %[[IS_ERROR]], ^[[BRANCH_ERROR:.*]], ^[[BRANCH_OK:.*]]
|
||||
// CHECK: cf.cond_br %[[IS_ERROR]], ^[[BRANCH_ERROR:.*]], ^[[BRANCH_OK:.*]]
|
||||
|
||||
// CHECK: ^[[BRANCH_OK]]:
|
||||
// CHECK: %[[LOADED:.*]] = async.runtime.load %[[VAL_STORAGE]] : <f32>
|
||||
|
@ -37,19 +37,19 @@ func @simple_callee(%arg0: f32) -> (f32 {builtin.foo = "bar"}) {
|
|||
// CHECK: async.runtime.store %[[RETURNED]], %[[RETURNED_STORAGE]] : <f32>
|
||||
// CHECK: async.runtime.set_available %[[RETURNED_STORAGE]]
|
||||
// CHECK: async.runtime.set_available %[[TOKEN]]
|
||||
// CHECK: br ^[[CLEANUP]]
|
||||
// CHECK: cf.br ^[[CLEANUP]]
|
||||
%3 = arith.mulf %arg0, %2 : f32
|
||||
return %3: f32
|
||||
|
||||
// CHECK: ^[[BRANCH_ERROR]]:
|
||||
// CHECK: async.runtime.set_error %[[TOKEN]]
|
||||
// CHECK: async.runtime.set_error %[[RETURNED_STORAGE]]
|
||||
// CHECK: br ^[[CLEANUP]]
|
||||
// CHECK: cf.br ^[[CLEANUP]]
|
||||
|
||||
|
||||
// CHECK: ^[[CLEANUP]]:
|
||||
// CHECK: async.coro.free %[[ID]], %[[HDL]]
|
||||
// CHECK: br ^[[SUSPEND]]
|
||||
// CHECK: cf.br ^[[SUSPEND]]
|
||||
|
||||
// CHECK: ^[[SUSPEND]]:
|
||||
// CHECK: async.coro.end %[[HDL]]
|
||||
|
@ -63,7 +63,7 @@ func @simple_caller() -> f32 {
|
|||
// CHECK: %[[RETURNED_STORAGE:.*]] = async.runtime.create : !async.value<f32>
|
||||
// CHECK: %[[ID:.*]] = async.coro.id
|
||||
// CHECK: %[[HDL:.*]] = async.coro.begin %[[ID]]
|
||||
// CHECK: br ^[[ORIGINAL_ENTRY:.*]]
|
||||
// CHECK: cf.br ^[[ORIGINAL_ENTRY:.*]]
|
||||
// CHECK: ^[[ORIGINAL_ENTRY]]:
|
||||
|
||||
// CHECK: %[[CONSTANT:.*]] = arith.constant
|
||||
|
@ -77,28 +77,28 @@ func @simple_caller() -> f32 {
|
|||
|
||||
// CHECK: ^[[RESUME]]:
|
||||
// CHECK: %[[IS_TOKEN_ERROR:.*]] = async.runtime.is_error %[[RETURNED_TO_CALLER]]#0 : !async.token
|
||||
// CHECK: cond_br %[[IS_TOKEN_ERROR]], ^[[BRANCH_ERROR:.*]], ^[[BRANCH_TOKEN_OK:.*]]
|
||||
// CHECK: cf.cond_br %[[IS_TOKEN_ERROR]], ^[[BRANCH_ERROR:.*]], ^[[BRANCH_TOKEN_OK:.*]]
|
||||
|
||||
// CHECK: ^[[BRANCH_TOKEN_OK]]:
|
||||
// CHECK: %[[IS_VALUE_ERROR:.*]] = async.runtime.is_error %[[RETURNED_TO_CALLER]]#1 : !async.value<f32>
|
||||
// CHECK: cond_br %[[IS_VALUE_ERROR]], ^[[BRANCH_ERROR:.*]], ^[[BRANCH_VALUE_OK:.*]]
|
||||
// CHECK: cf.cond_br %[[IS_VALUE_ERROR]], ^[[BRANCH_ERROR:.*]], ^[[BRANCH_VALUE_OK:.*]]
|
||||
|
||||
// CHECK: ^[[BRANCH_VALUE_OK]]:
|
||||
// CHECK: %[[LOADED:.*]] = async.runtime.load %[[RETURNED_TO_CALLER]]#1 : <f32>
|
||||
// CHECK: async.runtime.store %[[LOADED]], %[[RETURNED_STORAGE]] : <f32>
|
||||
// CHECK: async.runtime.set_available %[[RETURNED_STORAGE]]
|
||||
// CHECK: async.runtime.set_available %[[TOKEN]]
|
||||
// CHECK: br ^[[CLEANUP]]
|
||||
// CHECK: cf.br ^[[CLEANUP]]
|
||||
return %r: f32
|
||||
// CHECK: ^[[BRANCH_ERROR]]:
|
||||
// CHECK: async.runtime.set_error %[[TOKEN]]
|
||||
// CHECK: async.runtime.set_error %[[RETURNED_STORAGE]]
|
||||
// CHECK: br ^[[CLEANUP]]
|
||||
// CHECK: cf.br ^[[CLEANUP]]
|
||||
|
||||
|
||||
// CHECK: ^[[CLEANUP]]:
|
||||
// CHECK: async.coro.free %[[ID]], %[[HDL]]
|
||||
// CHECK: br ^[[SUSPEND]]
|
||||
// CHECK: cf.br ^[[SUSPEND]]
|
||||
|
||||
// CHECK: ^[[SUSPEND]]:
|
||||
// CHECK: async.coro.end %[[HDL]]
|
||||
|
@ -112,7 +112,7 @@ func @double_caller() -> f32 {
|
|||
// CHECK: %[[RETURNED_STORAGE:.*]] = async.runtime.create : !async.value<f32>
|
||||
// CHECK: %[[ID:.*]] = async.coro.id
|
||||
// CHECK: %[[HDL:.*]] = async.coro.begin %[[ID]]
|
||||
// CHECK: br ^[[ORIGINAL_ENTRY:.*]]
|
||||
// CHECK: cf.br ^[[ORIGINAL_ENTRY:.*]]
|
||||
// CHECK: ^[[ORIGINAL_ENTRY]]:
|
||||
|
||||
// CHECK: %[[CONSTANT:.*]] = arith.constant
|
||||
|
@ -126,11 +126,11 @@ func @double_caller() -> f32 {
|
|||
|
||||
// CHECK: ^[[RESUME_1]]:
|
||||
// CHECK: %[[IS_TOKEN_ERROR_1:.*]] = async.runtime.is_error %[[RETURNED_TO_CALLER_1]]#0 : !async.token
|
||||
// CHECK: cond_br %[[IS_TOKEN_ERROR_1]], ^[[BRANCH_ERROR:.*]], ^[[BRANCH_TOKEN_OK_1:.*]]
|
||||
// CHECK: cf.cond_br %[[IS_TOKEN_ERROR_1]], ^[[BRANCH_ERROR:.*]], ^[[BRANCH_TOKEN_OK_1:.*]]
|
||||
|
||||
// CHECK: ^[[BRANCH_TOKEN_OK_1]]:
|
||||
// CHECK: %[[IS_VALUE_ERROR_1:.*]] = async.runtime.is_error %[[RETURNED_TO_CALLER_1]]#1 : !async.value<f32>
|
||||
// CHECK: cond_br %[[IS_VALUE_ERROR_1]], ^[[BRANCH_ERROR:.*]], ^[[BRANCH_VALUE_OK_1:.*]]
|
||||
// CHECK: cf.cond_br %[[IS_VALUE_ERROR_1]], ^[[BRANCH_ERROR:.*]], ^[[BRANCH_VALUE_OK_1:.*]]
|
||||
|
||||
// CHECK: ^[[BRANCH_VALUE_OK_1]]:
|
||||
// CHECK: %[[LOADED_1:.*]] = async.runtime.load %[[RETURNED_TO_CALLER_1]]#1 : <f32>
|
||||
|
@ -143,27 +143,27 @@ func @double_caller() -> f32 {
|
|||
|
||||
// CHECK: ^[[RESUME_2]]:
|
||||
// CHECK: %[[IS_TOKEN_ERROR_2:.*]] = async.runtime.is_error %[[RETURNED_TO_CALLER_2]]#0 : !async.token
|
||||
// CHECK: cond_br %[[IS_TOKEN_ERROR_2]], ^[[BRANCH_ERROR:.*]], ^[[BRANCH_TOKEN_OK_2:.*]]
|
||||
// CHECK: cf.cond_br %[[IS_TOKEN_ERROR_2]], ^[[BRANCH_ERROR:.*]], ^[[BRANCH_TOKEN_OK_2:.*]]
|
||||
|
||||
// CHECK: ^[[BRANCH_TOKEN_OK_2]]:
|
||||
// CHECK: %[[IS_VALUE_ERROR_2:.*]] = async.runtime.is_error %[[RETURNED_TO_CALLER_2]]#1 : !async.value<f32>
|
||||
// CHECK: cond_br %[[IS_VALUE_ERROR_2]], ^[[BRANCH_ERROR:.*]], ^[[BRANCH_VALUE_OK_2:.*]]
|
||||
// CHECK: cf.cond_br %[[IS_VALUE_ERROR_2]], ^[[BRANCH_ERROR:.*]], ^[[BRANCH_VALUE_OK_2:.*]]
|
||||
|
||||
// CHECK: ^[[BRANCH_VALUE_OK_2]]:
|
||||
// CHECK: %[[LOADED_2:.*]] = async.runtime.load %[[RETURNED_TO_CALLER_2]]#1 : <f32>
|
||||
// CHECK: async.runtime.store %[[LOADED_2]], %[[RETURNED_STORAGE]] : <f32>
|
||||
// CHECK: async.runtime.set_available %[[RETURNED_STORAGE]]
|
||||
// CHECK: async.runtime.set_available %[[TOKEN]]
|
||||
// CHECK: br ^[[CLEANUP]]
|
||||
// CHECK: cf.br ^[[CLEANUP]]
|
||||
return %s: f32
|
||||
// CHECK: ^[[BRANCH_ERROR]]:
|
||||
// CHECK: async.runtime.set_error %[[TOKEN]]
|
||||
// CHECK: async.runtime.set_error %[[RETURNED_STORAGE]]
|
||||
// CHECK: br ^[[CLEANUP]]
|
||||
// CHECK: cf.br ^[[CLEANUP]]
|
||||
|
||||
// CHECK: ^[[CLEANUP]]:
|
||||
// CHECK: async.coro.free %[[ID]], %[[HDL]]
|
||||
// CHECK: br ^[[SUSPEND]]
|
||||
// CHECK: cf.br ^[[SUSPEND]]
|
||||
|
||||
// CHECK: ^[[SUSPEND]]:
|
||||
// CHECK: async.coro.end %[[HDL]]
|
||||
|
@ -184,7 +184,7 @@ func @recursive(%arg: !async.token) {
|
|||
async.await %arg : !async.token
|
||||
// CHECK: ^[[RESUME_1]]:
|
||||
// CHECK: %[[IS_ERROR:.*]] = async.runtime.is_error %[[ARG]] : !async.token
|
||||
// CHECK: cond_br %[[IS_ERROR]], ^[[BRANCH_ERROR:.*]], ^[[BRANCH_OK:.*]]
|
||||
// CHECK: cf.cond_br %[[IS_ERROR]], ^[[BRANCH_ERROR:.*]], ^[[BRANCH_OK:.*]]
|
||||
|
||||
// CHECK: ^[[BRANCH_OK]]:
|
||||
// CHECK: %[[GIVEN:.*]] = async.runtime.create : !async.token
|
||||
|
@ -200,16 +200,16 @@ call @recursive(%r): (!async.token) -> ()
|
|||
|
||||
// CHECK: ^[[RESUME_2]]:
|
||||
// CHECK: async.runtime.set_available %[[TOKEN]]
|
||||
// CHECK: br ^[[CLEANUP]]
|
||||
// CHECK: cf.br ^[[CLEANUP]]
|
||||
|
||||
// CHECK: ^[[BRANCH_ERROR]]:
|
||||
// CHECK: async.runtime.set_error %[[TOKEN]]
|
||||
// CHECK: br ^[[CLEANUP]]
|
||||
// CHECK: cf.br ^[[CLEANUP]]
|
||||
return
|
||||
|
||||
// CHECK: ^[[CLEANUP]]:
|
||||
// CHECK: async.coro.free %[[ID]], %[[HDL]]
|
||||
// CHECK: br ^[[SUSPEND]]
|
||||
// CHECK: cf.br ^[[SUSPEND]]
|
||||
|
||||
// CHECK: ^[[SUSPEND]]:
|
||||
// CHECK: async.coro.end %[[HDL]]
|
||||
|
@ -230,7 +230,7 @@ func @corecursive1(%arg: !async.token) {
|
|||
async.await %arg : !async.token
|
||||
// CHECK: ^[[RESUME_1]]:
|
||||
// CHECK: %[[IS_ERROR:.*]] = async.runtime.is_error %[[ARG]] : !async.token
|
||||
// CHECK: cond_br %[[IS_ERROR]], ^[[BRANCH_ERROR:.*]], ^[[BRANCH_OK:.*]]
|
||||
// CHECK: cf.cond_br %[[IS_ERROR]], ^[[BRANCH_ERROR:.*]], ^[[BRANCH_OK:.*]]
|
||||
|
||||
// CHECK: ^[[BRANCH_OK]]:
|
||||
// CHECK: %[[GIVEN:.*]] = async.runtime.create : !async.token
|
||||
|
@ -246,16 +246,16 @@ call @corecursive2(%r): (!async.token) -> ()
|
|||
|
||||
// CHECK: ^[[RESUME_2]]:
|
||||
// CHECK: async.runtime.set_available %[[TOKEN]]
|
||||
// CHECK: br ^[[CLEANUP]]
|
||||
// CHECK: cf.br ^[[CLEANUP]]
|
||||
|
||||
// CHECK: ^[[BRANCH_ERROR]]:
|
||||
// CHECK: async.runtime.set_error %[[TOKEN]]
|
||||
// CHECK: br ^[[CLEANUP]]
|
||||
// CHECK: cf.br ^[[CLEANUP]]
|
||||
return
|
||||
|
||||
// CHECK: ^[[CLEANUP]]:
|
||||
// CHECK: async.coro.free %[[ID]], %[[HDL]]
|
||||
// CHECK: br ^[[SUSPEND]]
|
||||
// CHECK: cf.br ^[[SUSPEND]]
|
||||
|
||||
// CHECK: ^[[SUSPEND]]:
|
||||
// CHECK: async.coro.end %[[HDL]]
|
||||
|
@ -276,7 +276,7 @@ func @corecursive2(%arg: !async.token) {
|
|||
async.await %arg : !async.token
|
||||
// CHECK: ^[[RESUME_1]]:
|
||||
// CHECK: %[[IS_ERROR:.*]] = async.runtime.is_error %[[ARG]] : !async.token
|
||||
// CHECK: cond_br %[[IS_ERROR]], ^[[BRANCH_ERROR:.*]], ^[[BRANCH_OK:.*]]
|
||||
// CHECK: cf.cond_br %[[IS_ERROR]], ^[[BRANCH_ERROR:.*]], ^[[BRANCH_OK:.*]]
|
||||
|
||||
// CHECK: ^[[BRANCH_OK]]:
|
||||
// CHECK: %[[GIVEN:.*]] = async.runtime.create : !async.token
|
||||
|
@ -292,16 +292,16 @@ call @corecursive1(%r): (!async.token) -> ()
|
|||
|
||||
// CHECK: ^[[RESUME_2]]:
|
||||
// CHECK: async.runtime.set_available %[[TOKEN]]
|
||||
// CHECK: br ^[[CLEANUP]]
|
||||
// CHECK: cf.br ^[[CLEANUP]]
|
||||
|
||||
// CHECK: ^[[BRANCH_ERROR]]:
|
||||
// CHECK: async.runtime.set_error %[[TOKEN]]
|
||||
// CHECK: br ^[[CLEANUP]]
|
||||
// CHECK: cf.br ^[[CLEANUP]]
|
||||
return
|
||||
|
||||
// CHECK: ^[[CLEANUP]]:
|
||||
// CHECK: async.coro.free %[[ID]], %[[HDL]]
|
||||
// CHECK: br ^[[SUSPEND]]
|
||||
// CHECK: cf.br ^[[SUSPEND]]
|
||||
|
||||
// CHECK: ^[[SUSPEND]]:
|
||||
// CHECK: async.coro.end %[[HDL]]
|
||||
|
|
|
@ -63,7 +63,7 @@ func @nested_async_execute(%arg0: f32, %arg1: f32, %arg2: memref<1xf32>) {
|
|||
// CHECK: %[[IS_ERROR:.*]] = async.runtime.is_error %[[TOKEN]]
|
||||
// CHECK: %[[TRUE:.*]] = arith.constant true
|
||||
// CHECK: %[[NOT_ERROR:.*]] = arith.xori %[[IS_ERROR]], %[[TRUE]] : i1
|
||||
// CHECK: assert %[[NOT_ERROR]]
|
||||
// CHECK: cf.assert %[[NOT_ERROR]]
|
||||
// CHECK-NEXT: return
|
||||
async.await %token0 : !async.token
|
||||
return
|
||||
|
@ -109,7 +109,7 @@ func @nested_async_execute(%arg0: f32, %arg1: f32, %arg2: memref<1xf32>) {
|
|||
// Check the error of the awaited token after resumption.
|
||||
// CHECK: ^[[RESUME_1]]:
|
||||
// CHECK: %[[ERR:.*]] = async.runtime.is_error %[[INNER_TOKEN]]
|
||||
// CHECK: cond_br %[[ERR]], ^[[SET_ERROR:.*]], ^[[CONTINUATION:.*]]
|
||||
// CHECK: cf.cond_br %[[ERR]], ^[[SET_ERROR:.*]], ^[[CONTINUATION:.*]]
|
||||
|
||||
// Set token available if the token is not in the error state.
|
||||
// CHECK: ^[[CONTINUATION:.*]]:
|
||||
|
@ -169,7 +169,7 @@ func @async_execute_token_dependency(%arg0: f32, %arg1: memref<1xf32>) {
|
|||
// Check the error of the awaited token after resumption.
|
||||
// CHECK: ^[[RESUME_1]]:
|
||||
// CHECK: %[[ERR:.*]] = async.runtime.is_error %[[ARG0]]
|
||||
// CHECK: cond_br %[[ERR]], ^[[SET_ERROR:.*]], ^[[CONTINUATION:.*]]
|
||||
// CHECK: cf.cond_br %[[ERR]], ^[[SET_ERROR:.*]], ^[[CONTINUATION:.*]]
|
||||
|
||||
// Emplace result token after second resumption and error checking.
|
||||
// CHECK: ^[[CONTINUATION:.*]]:
|
||||
|
@ -225,7 +225,7 @@ func @async_group_await_all(%arg0: f32, %arg1: memref<1xf32>) {
|
|||
// Check the error of the awaited token after resumption.
|
||||
// CHECK: ^[[RESUME_1]]:
|
||||
// CHECK: %[[ERR:.*]] = async.runtime.is_error %[[ARG]]
|
||||
// CHECK: cond_br %[[ERR]], ^[[SET_ERROR:.*]], ^[[CONTINUATION:.*]]
|
||||
// CHECK: cf.cond_br %[[ERR]], ^[[SET_ERROR:.*]], ^[[CONTINUATION:.*]]
|
||||
|
||||
// Emplace result token after error checking.
|
||||
// CHECK: ^[[CONTINUATION:.*]]:
|
||||
|
@ -319,7 +319,7 @@ func @async_value_operands() {
|
|||
// Check the error of the awaited token after resumption.
|
||||
// CHECK: ^[[RESUME_1]]:
|
||||
// CHECK: %[[ERR:.*]] = async.runtime.is_error %[[ARG]]
|
||||
// CHECK: cond_br %[[ERR]], ^[[SET_ERROR:.*]], ^[[CONTINUATION:.*]]
|
||||
// CHECK: cf.cond_br %[[ERR]], ^[[SET_ERROR:.*]], ^[[CONTINUATION:.*]]
|
||||
|
||||
// Load from the async.value argument after error checking.
|
||||
// CHECK: ^[[CONTINUATION:.*]]:
|
||||
|
@ -335,7 +335,7 @@ func @async_value_operands() {
|
|||
// CHECK-LABEL: @execute_assertion
|
||||
func @execute_assertion(%arg0: i1) {
|
||||
%token = async.execute {
|
||||
assert %arg0, "error"
|
||||
cf.assert %arg0, "error"
|
||||
async.yield
|
||||
}
|
||||
async.await %token : !async.token
|
||||
|
@ -358,17 +358,17 @@ func @execute_assertion(%arg0: i1) {
|
|||
|
||||
// Resume coroutine after suspension.
|
||||
// CHECK: ^[[RESUME]]:
|
||||
// CHECK: cond_br %[[ARG0]], ^[[SET_AVAILABLE:.*]], ^[[SET_ERROR:.*]]
|
||||
// CHECK: cf.cond_br %[[ARG0]], ^[[SET_AVAILABLE:.*]], ^[[SET_ERROR:.*]]
|
||||
|
||||
// Set coroutine completion token to available state.
|
||||
// CHECK: ^[[SET_AVAILABLE]]:
|
||||
// CHECK: async.runtime.set_available %[[TOKEN]]
|
||||
// CHECK: br ^[[CLEANUP]]
|
||||
// CHECK: cf.br ^[[CLEANUP]]
|
||||
|
||||
// Set coroutine completion token to error state.
|
||||
// CHECK: ^[[SET_ERROR]]:
|
||||
// CHECK: async.runtime.set_error %[[TOKEN]]
|
||||
// CHECK: br ^[[CLEANUP]]
|
||||
// CHECK: cf.br ^[[CLEANUP]]
|
||||
|
||||
// Delete coroutine.
|
||||
// CHECK: ^[[CLEANUP]]:
|
||||
|
@ -409,7 +409,7 @@ func @lower_scf_to_cfg(%arg0: f32, %arg1: memref<1xf32>, %arg2: i1) {
|
|||
|
||||
// Check that structured control flow lowered to CFG.
|
||||
// CHECK-NOT: scf.if
|
||||
// CHECK: cond_br %[[FLAG]]
|
||||
// CHECK: cf.cond_br %[[FLAG]]
|
||||
|
||||
// -----
|
||||
// Constants captured by the async.execute region should be cloned into the
|
||||
|
|
|
@ -17,26 +17,26 @@
|
|||
|
||||
// CHECK-LABEL: func @condBranch
|
||||
func @condBranch(%arg0: i1, %arg1: memref<2xf32>, %arg2: memref<2xf32>) {
|
||||
cond_br %arg0, ^bb1, ^bb2
|
||||
cf.cond_br %arg0, ^bb1, ^bb2
|
||||
^bb1:
|
||||
br ^bb3(%arg1 : memref<2xf32>)
|
||||
cf.br ^bb3(%arg1 : memref<2xf32>)
|
||||
^bb2:
|
||||
%0 = memref.alloc() : memref<2xf32>
|
||||
test.buffer_based in(%arg1: memref<2xf32>) out(%0: memref<2xf32>)
|
||||
br ^bb3(%0 : memref<2xf32>)
|
||||
cf.br ^bb3(%0 : memref<2xf32>)
|
||||
^bb3(%1: memref<2xf32>):
|
||||
test.copy(%1, %arg2) : (memref<2xf32>, memref<2xf32>)
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK-NEXT: cond_br
|
||||
// CHECK-NEXT: cf.cond_br
|
||||
// CHECK: %[[ALLOC0:.*]] = bufferization.clone
|
||||
// CHECK-NEXT: br ^bb3(%[[ALLOC0]]
|
||||
// CHECK-NEXT: cf.br ^bb3(%[[ALLOC0]]
|
||||
// CHECK: %[[ALLOC1:.*]] = memref.alloc
|
||||
// CHECK-NEXT: test.buffer_based
|
||||
// CHECK-NEXT: %[[ALLOC2:.*]] = bufferization.clone %[[ALLOC1]]
|
||||
// CHECK-NEXT: memref.dealloc %[[ALLOC1]]
|
||||
// CHECK-NEXT: br ^bb3(%[[ALLOC2]]
|
||||
// CHECK-NEXT: cf.br ^bb3(%[[ALLOC2]]
|
||||
// CHECK: test.copy
|
||||
// CHECK-NEXT: memref.dealloc
|
||||
// CHECK-NEXT: return
|
||||
|
@ -62,27 +62,27 @@ func @condBranchDynamicType(
|
|||
%arg1: memref<?xf32>,
|
||||
%arg2: memref<?xf32>,
|
||||
%arg3: index) {
|
||||
cond_br %arg0, ^bb1, ^bb2(%arg3: index)
|
||||
cf.cond_br %arg0, ^bb1, ^bb2(%arg3: index)
|
||||
^bb1:
|
||||
br ^bb3(%arg1 : memref<?xf32>)
|
||||
cf.br ^bb3(%arg1 : memref<?xf32>)
|
||||
^bb2(%0: index):
|
||||
%1 = memref.alloc(%0) : memref<?xf32>
|
||||
test.buffer_based in(%arg1: memref<?xf32>) out(%1: memref<?xf32>)
|
||||
br ^bb3(%1 : memref<?xf32>)
|
||||
cf.br ^bb3(%1 : memref<?xf32>)
|
||||
^bb3(%2: memref<?xf32>):
|
||||
test.copy(%2, %arg2) : (memref<?xf32>, memref<?xf32>)
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK-NEXT: cond_br
|
||||
// CHECK-NEXT: cf.cond_br
|
||||
// CHECK: %[[ALLOC0:.*]] = bufferization.clone
|
||||
// CHECK-NEXT: br ^bb3(%[[ALLOC0]]
|
||||
// CHECK-NEXT: cf.br ^bb3(%[[ALLOC0]]
|
||||
// CHECK: ^bb2(%[[IDX:.*]]:{{.*}})
|
||||
// CHECK-NEXT: %[[ALLOC1:.*]] = memref.alloc(%[[IDX]])
|
||||
// CHECK-NEXT: test.buffer_based
|
||||
// CHECK-NEXT: %[[ALLOC2:.*]] = bufferization.clone
|
||||
// CHECK-NEXT: memref.dealloc %[[ALLOC1]]
|
||||
// CHECK-NEXT: br ^bb3
|
||||
// CHECK-NEXT: cf.br ^bb3
|
||||
// CHECK-NEXT: ^bb3(%[[ALLOC3:.*]]:{{.*}})
|
||||
// CHECK: test.copy(%[[ALLOC3]],
|
||||
// CHECK-NEXT: memref.dealloc %[[ALLOC3]]
|
||||
|
@ -98,28 +98,28 @@ func @condBranchUnrankedType(
|
|||
%arg1: memref<*xf32>,
|
||||
%arg2: memref<*xf32>,
|
||||
%arg3: index) {
|
||||
cond_br %arg0, ^bb1, ^bb2(%arg3: index)
|
||||
cf.cond_br %arg0, ^bb1, ^bb2(%arg3: index)
|
||||
^bb1:
|
||||
br ^bb3(%arg1 : memref<*xf32>)
|
||||
cf.br ^bb3(%arg1 : memref<*xf32>)
|
||||
^bb2(%0: index):
|
||||
%1 = memref.alloc(%0) : memref<?xf32>
|
||||
%2 = memref.cast %1 : memref<?xf32> to memref<*xf32>
|
||||
test.buffer_based in(%arg1: memref<*xf32>) out(%2: memref<*xf32>)
|
||||
br ^bb3(%2 : memref<*xf32>)
|
||||
cf.br ^bb3(%2 : memref<*xf32>)
|
||||
^bb3(%3: memref<*xf32>):
|
||||
test.copy(%3, %arg2) : (memref<*xf32>, memref<*xf32>)
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK-NEXT: cond_br
|
||||
// CHECK-NEXT: cf.cond_br
|
||||
// CHECK: %[[ALLOC0:.*]] = bufferization.clone
|
||||
// CHECK-NEXT: br ^bb3(%[[ALLOC0]]
|
||||
// CHECK-NEXT: cf.br ^bb3(%[[ALLOC0]]
|
||||
// CHECK: ^bb2(%[[IDX:.*]]:{{.*}})
|
||||
// CHECK-NEXT: %[[ALLOC1:.*]] = memref.alloc(%[[IDX]])
|
||||
// CHECK: test.buffer_based
|
||||
// CHECK-NEXT: %[[ALLOC2:.*]] = bufferization.clone
|
||||
// CHECK-NEXT: memref.dealloc %[[ALLOC1]]
|
||||
// CHECK-NEXT: br ^bb3
|
||||
// CHECK-NEXT: cf.br ^bb3
|
||||
// CHECK-NEXT: ^bb3(%[[ALLOC3:.*]]:{{.*}})
|
||||
// CHECK: test.copy(%[[ALLOC3]],
|
||||
// CHECK-NEXT: memref.dealloc %[[ALLOC3]]
|
||||
|
@ -153,44 +153,44 @@ func @condBranchDynamicTypeNested(
|
|||
%arg1: memref<?xf32>,
|
||||
%arg2: memref<?xf32>,
|
||||
%arg3: index) {
|
||||
cond_br %arg0, ^bb1, ^bb2(%arg3: index)
|
||||
cf.cond_br %arg0, ^bb1, ^bb2(%arg3: index)
|
||||
^bb1:
|
||||
br ^bb6(%arg1 : memref<?xf32>)
|
||||
cf.br ^bb6(%arg1 : memref<?xf32>)
|
||||
^bb2(%0: index):
|
||||
%1 = memref.alloc(%0) : memref<?xf32>
|
||||
test.buffer_based in(%arg1: memref<?xf32>) out(%1: memref<?xf32>)
|
||||
cond_br %arg0, ^bb3, ^bb4
|
||||
cf.cond_br %arg0, ^bb3, ^bb4
|
||||
^bb3:
|
||||
br ^bb5(%1 : memref<?xf32>)
|
||||
cf.br ^bb5(%1 : memref<?xf32>)
|
||||
^bb4:
|
||||
br ^bb5(%1 : memref<?xf32>)
|
||||
cf.br ^bb5(%1 : memref<?xf32>)
|
||||
^bb5(%2: memref<?xf32>):
|
||||
br ^bb6(%2 : memref<?xf32>)
|
||||
cf.br ^bb6(%2 : memref<?xf32>)
|
||||
^bb6(%3: memref<?xf32>):
|
||||
br ^bb7(%3 : memref<?xf32>)
|
||||
cf.br ^bb7(%3 : memref<?xf32>)
|
||||
^bb7(%4: memref<?xf32>):
|
||||
test.copy(%4, %arg2) : (memref<?xf32>, memref<?xf32>)
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK-NEXT: cond_br{{.*}}
|
||||
// CHECK-NEXT: cf.cond_br{{.*}}
|
||||
// CHECK-NEXT: ^bb1
|
||||
// CHECK-NEXT: %[[ALLOC0:.*]] = bufferization.clone
|
||||
// CHECK-NEXT: br ^bb6(%[[ALLOC0]]
|
||||
// CHECK-NEXT: cf.br ^bb6(%[[ALLOC0]]
|
||||
// CHECK: ^bb2(%[[IDX:.*]]:{{.*}})
|
||||
// CHECK-NEXT: %[[ALLOC1:.*]] = memref.alloc(%[[IDX]])
|
||||
// CHECK-NEXT: test.buffer_based
|
||||
// CHECK: cond_br
|
||||
// CHECK: cf.cond_br
|
||||
// CHECK: ^bb3:
|
||||
// CHECK-NEXT: br ^bb5(%[[ALLOC1]]{{.*}})
|
||||
// CHECK-NEXT: cf.br ^bb5(%[[ALLOC1]]{{.*}})
|
||||
// CHECK: ^bb4:
|
||||
// CHECK-NEXT: br ^bb5(%[[ALLOC1]]{{.*}})
|
||||
// CHECK-NEXT: cf.br ^bb5(%[[ALLOC1]]{{.*}})
|
||||
// CHECK-NEXT: ^bb5(%[[ALLOC2:.*]]:{{.*}})
|
||||
// CHECK-NEXT: %[[ALLOC3:.*]] = bufferization.clone %[[ALLOC2]]
|
||||
// CHECK-NEXT: memref.dealloc %[[ALLOC1]]
|
||||
// CHECK-NEXT: br ^bb6(%[[ALLOC3]]{{.*}})
|
||||
// CHECK-NEXT: cf.br ^bb6(%[[ALLOC3]]{{.*}})
|
||||
// CHECK-NEXT: ^bb6(%[[ALLOC4:.*]]:{{.*}})
|
||||
// CHECK-NEXT: br ^bb7(%[[ALLOC4]]{{.*}})
|
||||
// CHECK-NEXT: cf.br ^bb7(%[[ALLOC4]]{{.*}})
|
||||
// CHECK-NEXT: ^bb7(%[[ALLOC5:.*]]:{{.*}})
|
||||
// CHECK: test.copy(%[[ALLOC5]],
|
||||
// CHECK-NEXT: memref.dealloc %[[ALLOC4]]
|
||||
|
@ -225,18 +225,18 @@ func @emptyUsesValue(%arg0: memref<4xf32>) {
|
|||
|
||||
// CHECK-LABEL: func @criticalEdge
|
||||
func @criticalEdge(%arg0: i1, %arg1: memref<2xf32>, %arg2: memref<2xf32>) {
|
||||
cond_br %arg0, ^bb1, ^bb2(%arg1 : memref<2xf32>)
|
||||
cf.cond_br %arg0, ^bb1, ^bb2(%arg1 : memref<2xf32>)
|
||||
^bb1:
|
||||
%0 = memref.alloc() : memref<2xf32>
|
||||
test.buffer_based in(%arg1: memref<2xf32>) out(%0: memref<2xf32>)
|
||||
br ^bb2(%0 : memref<2xf32>)
|
||||
cf.br ^bb2(%0 : memref<2xf32>)
|
||||
^bb2(%1: memref<2xf32>):
|
||||
test.copy(%1, %arg2) : (memref<2xf32>, memref<2xf32>)
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK-NEXT: %[[ALLOC0:.*]] = bufferization.clone
|
||||
// CHECK-NEXT: cond_br
|
||||
// CHECK-NEXT: cf.cond_br
|
||||
// CHECK: %[[ALLOC1:.*]] = memref.alloc()
|
||||
// CHECK-NEXT: test.buffer_based
|
||||
// CHECK-NEXT: %[[ALLOC2:.*]] = bufferization.clone %[[ALLOC1]]
|
||||
|
@ -260,9 +260,9 @@ func @criticalEdge(%arg0: i1, %arg1: memref<2xf32>, %arg2: memref<2xf32>) {
|
|||
func @invCriticalEdge(%arg0: i1, %arg1: memref<2xf32>, %arg2: memref<2xf32>) {
|
||||
%0 = memref.alloc() : memref<2xf32>
|
||||
test.buffer_based in(%arg1: memref<2xf32>) out(%0: memref<2xf32>)
|
||||
cond_br %arg0, ^bb1, ^bb2(%arg1 : memref<2xf32>)
|
||||
cf.cond_br %arg0, ^bb1, ^bb2(%arg1 : memref<2xf32>)
|
||||
^bb1:
|
||||
br ^bb2(%0 : memref<2xf32>)
|
||||
cf.br ^bb2(%0 : memref<2xf32>)
|
||||
^bb2(%1: memref<2xf32>):
|
||||
test.copy(%1, %arg2) : (memref<2xf32>, memref<2xf32>)
|
||||
return
|
||||
|
@ -288,13 +288,13 @@ func @invCriticalEdge(%arg0: i1, %arg1: memref<2xf32>, %arg2: memref<2xf32>) {
|
|||
func @ifElse(%arg0: i1, %arg1: memref<2xf32>, %arg2: memref<2xf32>) {
|
||||
%0 = memref.alloc() : memref<2xf32>
|
||||
test.buffer_based in(%arg1: memref<2xf32>) out(%0: memref<2xf32>)
|
||||
cond_br %arg0,
|
||||
cf.cond_br %arg0,
|
||||
^bb1(%arg1, %0 : memref<2xf32>, memref<2xf32>),
|
||||
^bb2(%0, %arg1 : memref<2xf32>, memref<2xf32>)
|
||||
^bb1(%1: memref<2xf32>, %2: memref<2xf32>):
|
||||
br ^bb3(%1, %2 : memref<2xf32>, memref<2xf32>)
|
||||
cf.br ^bb3(%1, %2 : memref<2xf32>, memref<2xf32>)
|
||||
^bb2(%3: memref<2xf32>, %4: memref<2xf32>):
|
||||
br ^bb3(%3, %4 : memref<2xf32>, memref<2xf32>)
|
||||
cf.br ^bb3(%3, %4 : memref<2xf32>, memref<2xf32>)
|
||||
^bb3(%5: memref<2xf32>, %6: memref<2xf32>):
|
||||
%7 = memref.alloc() : memref<2xf32>
|
||||
test.buffer_based in(%5: memref<2xf32>) out(%7: memref<2xf32>)
|
||||
|
@ -326,13 +326,13 @@ func @ifElse(%arg0: i1, %arg1: memref<2xf32>, %arg2: memref<2xf32>) {
|
|||
func @ifElseNoUsers(%arg0: i1, %arg1: memref<2xf32>, %arg2: memref<2xf32>) {
|
||||
%0 = memref.alloc() : memref<2xf32>
|
||||
test.buffer_based in(%arg1: memref<2xf32>) out(%0: memref<2xf32>)
|
||||
cond_br %arg0,
|
||||
cf.cond_br %arg0,
|
||||
^bb1(%arg1, %0 : memref<2xf32>, memref<2xf32>),
|
||||
^bb2(%0, %arg1 : memref<2xf32>, memref<2xf32>)
|
||||
^bb1(%1: memref<2xf32>, %2: memref<2xf32>):
|
||||
br ^bb3(%1, %2 : memref<2xf32>, memref<2xf32>)
|
||||
cf.br ^bb3(%1, %2 : memref<2xf32>, memref<2xf32>)
|
||||
^bb2(%3: memref<2xf32>, %4: memref<2xf32>):
|
||||
br ^bb3(%3, %4 : memref<2xf32>, memref<2xf32>)
|
||||
cf.br ^bb3(%3, %4 : memref<2xf32>, memref<2xf32>)
|
||||
^bb3(%5: memref<2xf32>, %6: memref<2xf32>):
|
||||
test.copy(%arg1, %arg2) : (memref<2xf32>, memref<2xf32>)
|
||||
return
|
||||
|
@ -361,17 +361,17 @@ func @ifElseNoUsers(%arg0: i1, %arg1: memref<2xf32>, %arg2: memref<2xf32>) {
|
|||
func @ifElseNested(%arg0: i1, %arg1: memref<2xf32>, %arg2: memref<2xf32>) {
|
||||
%0 = memref.alloc() : memref<2xf32>
|
||||
test.buffer_based in(%arg1: memref<2xf32>) out(%0: memref<2xf32>)
|
||||
cond_br %arg0,
|
||||
cf.cond_br %arg0,
|
||||
^bb1(%arg1, %0 : memref<2xf32>, memref<2xf32>),
|
||||
^bb2(%0, %arg1 : memref<2xf32>, memref<2xf32>)
|
||||
^bb1(%1: memref<2xf32>, %2: memref<2xf32>):
|
||||
br ^bb5(%1, %2 : memref<2xf32>, memref<2xf32>)
|
||||
cf.br ^bb5(%1, %2 : memref<2xf32>, memref<2xf32>)
|
||||
^bb2(%3: memref<2xf32>, %4: memref<2xf32>):
|
||||
cond_br %arg0, ^bb3(%3 : memref<2xf32>), ^bb4(%4 : memref<2xf32>)
|
||||
cf.cond_br %arg0, ^bb3(%3 : memref<2xf32>), ^bb4(%4 : memref<2xf32>)
|
||||
^bb3(%5: memref<2xf32>):
|
||||
br ^bb5(%5, %3 : memref<2xf32>, memref<2xf32>)
|
||||
cf.br ^bb5(%5, %3 : memref<2xf32>, memref<2xf32>)
|
||||
^bb4(%6: memref<2xf32>):
|
||||
br ^bb5(%3, %6 : memref<2xf32>, memref<2xf32>)
|
||||
cf.br ^bb5(%3, %6 : memref<2xf32>, memref<2xf32>)
|
||||
^bb5(%7: memref<2xf32>, %8: memref<2xf32>):
|
||||
%9 = memref.alloc() : memref<2xf32>
|
||||
test.buffer_based in(%7: memref<2xf32>) out(%9: memref<2xf32>)
|
||||
|
@ -430,33 +430,33 @@ func @moving_alloc_and_inserting_missing_dealloc(
|
|||
%cond: i1,
|
||||
%arg0: memref<2xf32>,
|
||||
%arg1: memref<2xf32>) {
|
||||
cond_br %cond, ^bb1, ^bb2
|
||||
cf.cond_br %cond, ^bb1, ^bb2
|
||||
^bb1:
|
||||
%0 = memref.alloc() : memref<2xf32>
|
||||
test.buffer_based in(%arg0: memref<2xf32>) out(%0: memref<2xf32>)
|
||||
br ^exit(%0 : memref<2xf32>)
|
||||
cf.br ^exit(%0 : memref<2xf32>)
|
||||
^bb2:
|
||||
%1 = memref.alloc() : memref<2xf32>
|
||||
test.buffer_based in(%arg0: memref<2xf32>) out(%1: memref<2xf32>)
|
||||
br ^exit(%1 : memref<2xf32>)
|
||||
cf.br ^exit(%1 : memref<2xf32>)
|
||||
^exit(%arg2: memref<2xf32>):
|
||||
test.copy(%arg2, %arg1) : (memref<2xf32>, memref<2xf32>)
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK-NEXT: cond_br{{.*}}
|
||||
// CHECK-NEXT: cf.cond_br{{.*}}
|
||||
// CHECK-NEXT: ^bb1
|
||||
// CHECK: %[[ALLOC0:.*]] = memref.alloc()
|
||||
// CHECK-NEXT: test.buffer_based
|
||||
// CHECK-NEXT: %[[ALLOC1:.*]] = bufferization.clone %[[ALLOC0]]
|
||||
// CHECK-NEXT: memref.dealloc %[[ALLOC0]]
|
||||
// CHECK-NEXT: br ^bb3(%[[ALLOC1]]
|
||||
// CHECK-NEXT: cf.br ^bb3(%[[ALLOC1]]
|
||||
// CHECK-NEXT: ^bb2
|
||||
// CHECK-NEXT: %[[ALLOC2:.*]] = memref.alloc()
|
||||
// CHECK-NEXT: test.buffer_based
|
||||
// CHECK-NEXT: %[[ALLOC3:.*]] = bufferization.clone %[[ALLOC2]]
|
||||
// CHECK-NEXT: memref.dealloc %[[ALLOC2]]
|
||||
// CHECK-NEXT: br ^bb3(%[[ALLOC3]]
|
||||
// CHECK-NEXT: cf.br ^bb3(%[[ALLOC3]]
|
||||
// CHECK-NEXT: ^bb3(%[[ALLOC4:.*]]:{{.*}})
|
||||
// CHECK: test.copy
|
||||
// CHECK-NEXT: memref.dealloc %[[ALLOC4]]
|
||||
|
@ -480,20 +480,20 @@ func @moving_invalid_dealloc_op_complex(
|
|||
%arg0: memref<2xf32>,
|
||||
%arg1: memref<2xf32>) {
|
||||
%1 = memref.alloc() : memref<2xf32>
|
||||
cond_br %cond, ^bb1, ^bb2
|
||||
cf.cond_br %cond, ^bb1, ^bb2
|
||||
^bb1:
|
||||
br ^exit(%arg0 : memref<2xf32>)
|
||||
cf.br ^exit(%arg0 : memref<2xf32>)
|
||||
^bb2:
|
||||
test.buffer_based in(%arg0: memref<2xf32>) out(%1: memref<2xf32>)
|
||||
memref.dealloc %1 : memref<2xf32>
|
||||
br ^exit(%1 : memref<2xf32>)
|
||||
cf.br ^exit(%1 : memref<2xf32>)
|
||||
^exit(%arg2: memref<2xf32>):
|
||||
test.copy(%arg2, %arg1) : (memref<2xf32>, memref<2xf32>)
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK-NEXT: %[[ALLOC0:.*]] = memref.alloc()
|
||||
// CHECK-NEXT: cond_br
|
||||
// CHECK-NEXT: cf.cond_br
|
||||
// CHECK: test.copy
|
||||
// CHECK-NEXT: memref.dealloc %[[ALLOC0]]
|
||||
// CHECK-NEXT: return
|
||||
|
@ -548,9 +548,9 @@ func @nested_regions_and_cond_branch(
|
|||
%arg0: i1,
|
||||
%arg1: memref<2xf32>,
|
||||
%arg2: memref<2xf32>) {
|
||||
cond_br %arg0, ^bb1, ^bb2
|
||||
cf.cond_br %arg0, ^bb1, ^bb2
|
||||
^bb1:
|
||||
br ^bb3(%arg1 : memref<2xf32>)
|
||||
cf.br ^bb3(%arg1 : memref<2xf32>)
|
||||
^bb2:
|
||||
%0 = memref.alloc() : memref<2xf32>
|
||||
test.region_buffer_based in(%arg1: memref<2xf32>) out(%0: memref<2xf32>) {
|
||||
|
@ -560,13 +560,13 @@ func @nested_regions_and_cond_branch(
|
|||
%tmp1 = math.exp %gen1_arg0 : f32
|
||||
test.region_yield %tmp1 : f32
|
||||
}
|
||||
br ^bb3(%0 : memref<2xf32>)
|
||||
cf.br ^bb3(%0 : memref<2xf32>)
|
||||
^bb3(%1: memref<2xf32>):
|
||||
test.copy(%1, %arg2) : (memref<2xf32>, memref<2xf32>)
|
||||
return
|
||||
}
|
||||
// CHECK: (%[[cond:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}, %{{.*}}: {{.*}})
|
||||
// CHECK-NEXT: cond_br %[[cond]], ^[[BB1:.*]], ^[[BB2:.*]]
|
||||
// CHECK-NEXT: cf.cond_br %[[cond]], ^[[BB1:.*]], ^[[BB2:.*]]
|
||||
// CHECK: %[[ALLOC0:.*]] = bufferization.clone %[[ARG1]]
|
||||
// CHECK: ^[[BB2]]:
|
||||
// CHECK: %[[ALLOC1:.*]] = memref.alloc()
|
||||
|
@ -728,21 +728,21 @@ func @subview(%arg0 : index, %arg1 : index, %arg2 : memref<?x?xf32>) {
|
|||
|
||||
// CHECK-LABEL: func @condBranchAlloca
|
||||
func @condBranchAlloca(%arg0: i1, %arg1: memref<2xf32>, %arg2: memref<2xf32>) {
|
||||
cond_br %arg0, ^bb1, ^bb2
|
||||
cf.cond_br %arg0, ^bb1, ^bb2
|
||||
^bb1:
|
||||
br ^bb3(%arg1 : memref<2xf32>)
|
||||
cf.br ^bb3(%arg1 : memref<2xf32>)
|
||||
^bb2:
|
||||
%0 = memref.alloca() : memref<2xf32>
|
||||
test.buffer_based in(%arg1: memref<2xf32>) out(%0: memref<2xf32>)
|
||||
br ^bb3(%0 : memref<2xf32>)
|
||||
cf.br ^bb3(%0 : memref<2xf32>)
|
||||
^bb3(%1: memref<2xf32>):
|
||||
test.copy(%1, %arg2) : (memref<2xf32>, memref<2xf32>)
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK-NEXT: cond_br
|
||||
// CHECK-NEXT: cf.cond_br
|
||||
// CHECK: %[[ALLOCA:.*]] = memref.alloca()
|
||||
// CHECK: br ^bb3(%[[ALLOCA:.*]])
|
||||
// CHECK: cf.br ^bb3(%[[ALLOCA:.*]])
|
||||
// CHECK-NEXT: ^bb3
|
||||
// CHECK-NEXT: test.copy
|
||||
// CHECK-NEXT: return
|
||||
|
@ -757,13 +757,13 @@ func @condBranchAlloca(%arg0: i1, %arg1: memref<2xf32>, %arg2: memref<2xf32>) {
|
|||
func @ifElseAlloca(%arg0: i1, %arg1: memref<2xf32>, %arg2: memref<2xf32>) {
|
||||
%0 = memref.alloc() : memref<2xf32>
|
||||
test.buffer_based in(%arg1: memref<2xf32>) out(%0: memref<2xf32>)
|
||||
cond_br %arg0,
|
||||
cf.cond_br %arg0,
|
||||
^bb1(%arg1, %0 : memref<2xf32>, memref<2xf32>),
|
||||
^bb2(%0, %arg1 : memref<2xf32>, memref<2xf32>)
|
||||
^bb1(%1: memref<2xf32>, %2: memref<2xf32>):
|
||||
br ^bb3(%1, %2 : memref<2xf32>, memref<2xf32>)
|
||||
cf.br ^bb3(%1, %2 : memref<2xf32>, memref<2xf32>)
|
||||
^bb2(%3: memref<2xf32>, %4: memref<2xf32>):
|
||||
br ^bb3(%3, %4 : memref<2xf32>, memref<2xf32>)
|
||||
cf.br ^bb3(%3, %4 : memref<2xf32>, memref<2xf32>)
|
||||
^bb3(%5: memref<2xf32>, %6: memref<2xf32>):
|
||||
%7 = memref.alloca() : memref<2xf32>
|
||||
test.buffer_based in(%5: memref<2xf32>) out(%7: memref<2xf32>)
|
||||
|
@ -788,17 +788,17 @@ func @ifElseNestedAlloca(
|
|||
%arg2: memref<2xf32>) {
|
||||
%0 = memref.alloca() : memref<2xf32>
|
||||
test.buffer_based in(%arg1: memref<2xf32>) out(%0: memref<2xf32>)
|
||||
cond_br %arg0,
|
||||
cf.cond_br %arg0,
|
||||
^bb1(%arg1, %0 : memref<2xf32>, memref<2xf32>),
|
||||
^bb2(%0, %arg1 : memref<2xf32>, memref<2xf32>)
|
||||
^bb1(%1: memref<2xf32>, %2: memref<2xf32>):
|
||||
br ^bb5(%1, %2 : memref<2xf32>, memref<2xf32>)
|
||||
cf.br ^bb5(%1, %2 : memref<2xf32>, memref<2xf32>)
|
||||
^bb2(%3: memref<2xf32>, %4: memref<2xf32>):
|
||||
cond_br %arg0, ^bb3(%3 : memref<2xf32>), ^bb4(%4 : memref<2xf32>)
|
||||
cf.cond_br %arg0, ^bb3(%3 : memref<2xf32>), ^bb4(%4 : memref<2xf32>)
|
||||
^bb3(%5: memref<2xf32>):
|
||||
br ^bb5(%5, %3 : memref<2xf32>, memref<2xf32>)
|
||||
cf.br ^bb5(%5, %3 : memref<2xf32>, memref<2xf32>)
|
||||
^bb4(%6: memref<2xf32>):
|
||||
br ^bb5(%3, %6 : memref<2xf32>, memref<2xf32>)
|
||||
cf.br ^bb5(%3, %6 : memref<2xf32>, memref<2xf32>)
|
||||
^bb5(%7: memref<2xf32>, %8: memref<2xf32>):
|
||||
%9 = memref.alloc() : memref<2xf32>
|
||||
test.buffer_based in(%7: memref<2xf32>) out(%9: memref<2xf32>)
|
||||
|
@ -821,9 +821,9 @@ func @nestedRegionsAndCondBranchAlloca(
|
|||
%arg0: i1,
|
||||
%arg1: memref<2xf32>,
|
||||
%arg2: memref<2xf32>) {
|
||||
cond_br %arg0, ^bb1, ^bb2
|
||||
cf.cond_br %arg0, ^bb1, ^bb2
|
||||
^bb1:
|
||||
br ^bb3(%arg1 : memref<2xf32>)
|
||||
cf.br ^bb3(%arg1 : memref<2xf32>)
|
||||
^bb2:
|
||||
%0 = memref.alloc() : memref<2xf32>
|
||||
test.region_buffer_based in(%arg1: memref<2xf32>) out(%0: memref<2xf32>) {
|
||||
|
@ -833,13 +833,13 @@ func @nestedRegionsAndCondBranchAlloca(
|
|||
%tmp1 = math.exp %gen1_arg0 : f32
|
||||
test.region_yield %tmp1 : f32
|
||||
}
|
||||
br ^bb3(%0 : memref<2xf32>)
|
||||
cf.br ^bb3(%0 : memref<2xf32>)
|
||||
^bb3(%1: memref<2xf32>):
|
||||
test.copy(%1, %arg2) : (memref<2xf32>, memref<2xf32>)
|
||||
return
|
||||
}
|
||||
// CHECK: (%[[cond:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}, %{{.*}}: {{.*}})
|
||||
// CHECK-NEXT: cond_br %[[cond]], ^[[BB1:.*]], ^[[BB2:.*]]
|
||||
// CHECK-NEXT: cf.cond_br %[[cond]], ^[[BB1:.*]], ^[[BB2:.*]]
|
||||
// CHECK: ^[[BB1]]:
|
||||
// CHECK: %[[ALLOC0:.*]] = bufferization.clone
|
||||
// CHECK: ^[[BB2]]:
|
||||
|
@ -1103,11 +1103,11 @@ func @loop_dynalloc(
|
|||
%arg2: memref<?xf32>,
|
||||
%arg3: memref<?xf32>) {
|
||||
%const0 = arith.constant 0 : i32
|
||||
br ^loopHeader(%const0, %arg2 : i32, memref<?xf32>)
|
||||
cf.br ^loopHeader(%const0, %arg2 : i32, memref<?xf32>)
|
||||
|
||||
^loopHeader(%i : i32, %buff : memref<?xf32>):
|
||||
%lessThan = arith.cmpi slt, %i, %arg1 : i32
|
||||
cond_br %lessThan,
|
||||
cf.cond_br %lessThan,
|
||||
^loopBody(%i, %buff : i32, memref<?xf32>),
|
||||
^exit(%buff : memref<?xf32>)
|
||||
|
||||
|
@ -1116,7 +1116,7 @@ func @loop_dynalloc(
|
|||
%inc = arith.addi %val, %const1 : i32
|
||||
%size = arith.index_cast %inc : i32 to index
|
||||
%alloc1 = memref.alloc(%size) : memref<?xf32>
|
||||
br ^loopHeader(%inc, %alloc1 : i32, memref<?xf32>)
|
||||
cf.br ^loopHeader(%inc, %alloc1 : i32, memref<?xf32>)
|
||||
|
||||
^exit(%buff3 : memref<?xf32>):
|
||||
test.copy(%buff3, %arg3) : (memref<?xf32>, memref<?xf32>)
|
||||
|
@ -1136,17 +1136,17 @@ func @do_loop_alloc(
|
|||
%arg2: memref<2xf32>,
|
||||
%arg3: memref<2xf32>) {
|
||||
%const0 = arith.constant 0 : i32
|
||||
br ^loopBody(%const0, %arg2 : i32, memref<2xf32>)
|
||||
cf.br ^loopBody(%const0, %arg2 : i32, memref<2xf32>)
|
||||
|
||||
^loopBody(%val : i32, %buff2: memref<2xf32>):
|
||||
%const1 = arith.constant 1 : i32
|
||||
%inc = arith.addi %val, %const1 : i32
|
||||
%alloc1 = memref.alloc() : memref<2xf32>
|
||||
br ^loopHeader(%inc, %alloc1 : i32, memref<2xf32>)
|
||||
cf.br ^loopHeader(%inc, %alloc1 : i32, memref<2xf32>)
|
||||
|
||||
^loopHeader(%i : i32, %buff : memref<2xf32>):
|
||||
%lessThan = arith.cmpi slt, %i, %arg1 : i32
|
||||
cond_br %lessThan,
|
||||
cf.cond_br %lessThan,
|
||||
^loopBody(%i, %buff : i32, memref<2xf32>),
|
||||
^exit(%buff : memref<2xf32>)
|
||||
|
||||
|
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue