forked from OSchip/llvm-project
[mlir] Split out a new ControlFlow dialect from Standard
This dialect is intended to model lower-level, branch-based control-flow constructs. The initial set of operations is: AssertOp, BranchOp, CondBranchOp, and SwitchOp, all split out from the current standard dialect. See https://discourse.llvm.org/t/standard-dialect-the-final-chapter/6061 Differential Revision: https://reviews.llvm.org/D118966
This commit is contained in:
parent
edca177cbe
commit
ace01605e0
|
@ -27,8 +27,8 @@ namespace fir::support {
|
||||||
#define FLANG_NONCODEGEN_DIALECT_LIST \
|
#define FLANG_NONCODEGEN_DIALECT_LIST \
|
||||||
mlir::AffineDialect, FIROpsDialect, mlir::acc::OpenACCDialect, \
|
mlir::AffineDialect, FIROpsDialect, mlir::acc::OpenACCDialect, \
|
||||||
mlir::omp::OpenMPDialect, mlir::scf::SCFDialect, \
|
mlir::omp::OpenMPDialect, mlir::scf::SCFDialect, \
|
||||||
mlir::arith::ArithmeticDialect, mlir::StandardOpsDialect, \
|
mlir::arith::ArithmeticDialect, mlir::cf::ControlFlowDialect, \
|
||||||
mlir::vector::VectorDialect
|
mlir::StandardOpsDialect, mlir::vector::VectorDialect
|
||||||
|
|
||||||
// The definitive list of dialects used by flang.
|
// The definitive list of dialects used by flang.
|
||||||
#define FLANG_DIALECT_LIST \
|
#define FLANG_DIALECT_LIST \
|
||||||
|
|
|
@ -9,7 +9,7 @@
|
||||||
/// This file defines some shared command-line options that can be used when
|
/// This file defines some shared command-line options that can be used when
|
||||||
/// debugging the test tools. This file must be included into the tool.
|
/// debugging the test tools. This file must be included into the tool.
|
||||||
|
|
||||||
#include "mlir/Conversion/SCFToStandard/SCFToStandard.h"
|
#include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h"
|
||||||
#include "mlir/Pass/PassManager.h"
|
#include "mlir/Pass/PassManager.h"
|
||||||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
||||||
#include "mlir/Transforms/Passes.h"
|
#include "mlir/Transforms/Passes.h"
|
||||||
|
@ -139,7 +139,7 @@ inline void createDefaultFIROptimizerPassPipeline(mlir::PassManager &pm) {
|
||||||
|
|
||||||
// convert control flow to CFG form
|
// convert control flow to CFG form
|
||||||
fir::addCfgConversionPass(pm);
|
fir::addCfgConversionPass(pm);
|
||||||
pm.addPass(mlir::createLowerToCFGPass());
|
pm.addPass(mlir::createConvertSCFToCFPass());
|
||||||
|
|
||||||
pm.addPass(mlir::createCanonicalizerPass(config));
|
pm.addPass(mlir::createCanonicalizerPass(config));
|
||||||
}
|
}
|
||||||
|
|
|
@ -32,7 +32,7 @@ add_flang_library(FortranLower
|
||||||
FortranSemantics
|
FortranSemantics
|
||||||
MLIRAffineToStandard
|
MLIRAffineToStandard
|
||||||
MLIRLLVMIR
|
MLIRLLVMIR
|
||||||
MLIRSCFToStandard
|
MLIRSCFToControlFlow
|
||||||
MLIRStandard
|
MLIRStandard
|
||||||
|
|
||||||
LINK_COMPONENTS
|
LINK_COMPONENTS
|
||||||
|
|
|
@ -18,6 +18,7 @@
|
||||||
#include "flang/Optimizer/Dialect/FIROps.h"
|
#include "flang/Optimizer/Dialect/FIROps.h"
|
||||||
#include "flang/Optimizer/Support/TypeCode.h"
|
#include "flang/Optimizer/Support/TypeCode.h"
|
||||||
#include "mlir/Conversion/ArithmeticToLLVM/ArithmeticToLLVM.h"
|
#include "mlir/Conversion/ArithmeticToLLVM/ArithmeticToLLVM.h"
|
||||||
|
#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
|
||||||
#include "mlir/Conversion/LLVMCommon/Pattern.h"
|
#include "mlir/Conversion/LLVMCommon/Pattern.h"
|
||||||
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
|
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
|
||||||
#include "mlir/IR/BuiltinTypes.h"
|
#include "mlir/IR/BuiltinTypes.h"
|
||||||
|
@ -3293,6 +3294,8 @@ public:
|
||||||
mlir::populateStdToLLVMConversionPatterns(typeConverter, pattern);
|
mlir::populateStdToLLVMConversionPatterns(typeConverter, pattern);
|
||||||
mlir::arith::populateArithmeticToLLVMConversionPatterns(typeConverter,
|
mlir::arith::populateArithmeticToLLVMConversionPatterns(typeConverter,
|
||||||
pattern);
|
pattern);
|
||||||
|
mlir::cf::populateControlFlowToLLVMConversionPatterns(typeConverter,
|
||||||
|
pattern);
|
||||||
mlir::ConversionTarget target{*context};
|
mlir::ConversionTarget target{*context};
|
||||||
target.addLegalDialect<mlir::LLVM::LLVMDialect>();
|
target.addLegalDialect<mlir::LLVM::LLVMDialect>();
|
||||||
|
|
||||||
|
|
|
@ -13,6 +13,7 @@
|
||||||
#include "flang/Optimizer/Dialect/FIRDialect.h"
|
#include "flang/Optimizer/Dialect/FIRDialect.h"
|
||||||
#include "flang/Optimizer/Support/FIRContext.h"
|
#include "flang/Optimizer/Support/FIRContext.h"
|
||||||
#include "flang/Optimizer/Transforms/Passes.h"
|
#include "flang/Optimizer/Transforms/Passes.h"
|
||||||
|
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
|
||||||
#include "mlir/Dialect/SCF/SCF.h"
|
#include "mlir/Dialect/SCF/SCF.h"
|
||||||
#include "mlir/Transforms/DialectConversion.h"
|
#include "mlir/Transforms/DialectConversion.h"
|
||||||
#include "llvm/Support/Debug.h"
|
#include "llvm/Support/Debug.h"
|
||||||
|
@ -332,9 +333,9 @@ ArrayCopyAnalysis::arrayAccesses(ArrayLoadOp load) {
|
||||||
<< "add modify {" << *owner << "} to array value set\n");
|
<< "add modify {" << *owner << "} to array value set\n");
|
||||||
accesses.push_back(owner);
|
accesses.push_back(owner);
|
||||||
appendToQueue(update.getResult(1));
|
appendToQueue(update.getResult(1));
|
||||||
} else if (auto br = mlir::dyn_cast<mlir::BranchOp>(owner)) {
|
} else if (auto br = mlir::dyn_cast<mlir::cf::BranchOp>(owner)) {
|
||||||
branchOp(br.getDest(), br.getDestOperands());
|
branchOp(br.getDest(), br.getDestOperands());
|
||||||
} else if (auto br = mlir::dyn_cast<mlir::CondBranchOp>(owner)) {
|
} else if (auto br = mlir::dyn_cast<mlir::cf::CondBranchOp>(owner)) {
|
||||||
branchOp(br.getTrueDest(), br.getTrueOperands());
|
branchOp(br.getTrueDest(), br.getTrueOperands());
|
||||||
branchOp(br.getFalseDest(), br.getFalseOperands());
|
branchOp(br.getFalseDest(), br.getFalseOperands());
|
||||||
} else if (mlir::isa<ArrayMergeStoreOp>(owner)) {
|
} else if (mlir::isa<ArrayMergeStoreOp>(owner)) {
|
||||||
|
@ -789,9 +790,9 @@ public:
|
||||||
patterns1.insert<ArrayUpdateConversion>(context, analysis, useMap);
|
patterns1.insert<ArrayUpdateConversion>(context, analysis, useMap);
|
||||||
patterns1.insert<ArrayModifyConversion>(context, analysis, useMap);
|
patterns1.insert<ArrayModifyConversion>(context, analysis, useMap);
|
||||||
mlir::ConversionTarget target(*context);
|
mlir::ConversionTarget target(*context);
|
||||||
target.addLegalDialect<FIROpsDialect, mlir::scf::SCFDialect,
|
target.addLegalDialect<
|
||||||
mlir::arith::ArithmeticDialect,
|
FIROpsDialect, mlir::scf::SCFDialect, mlir::arith::ArithmeticDialect,
|
||||||
mlir::StandardOpsDialect>();
|
mlir::cf::ControlFlowDialect, mlir::StandardOpsDialect>();
|
||||||
target.addIllegalOp<ArrayFetchOp, ArrayUpdateOp, ArrayModifyOp>();
|
target.addIllegalOp<ArrayFetchOp, ArrayUpdateOp, ArrayModifyOp>();
|
||||||
// Rewrite the array fetch and array update ops.
|
// Rewrite the array fetch and array update ops.
|
||||||
if (mlir::failed(
|
if (mlir::failed(
|
||||||
|
|
|
@ -11,6 +11,7 @@
|
||||||
#include "flang/Optimizer/Dialect/FIROps.h"
|
#include "flang/Optimizer/Dialect/FIROps.h"
|
||||||
#include "flang/Optimizer/Transforms/Passes.h"
|
#include "flang/Optimizer/Transforms/Passes.h"
|
||||||
#include "mlir/Dialect/Affine/IR/AffineOps.h"
|
#include "mlir/Dialect/Affine/IR/AffineOps.h"
|
||||||
|
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
|
||||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||||
#include "mlir/Pass/Pass.h"
|
#include "mlir/Pass/Pass.h"
|
||||||
#include "mlir/Transforms/DialectConversion.h"
|
#include "mlir/Transforms/DialectConversion.h"
|
||||||
|
@ -84,7 +85,7 @@ public:
|
||||||
loopOperands.append(operands.begin(), operands.end());
|
loopOperands.append(operands.begin(), operands.end());
|
||||||
loopOperands.push_back(iters);
|
loopOperands.push_back(iters);
|
||||||
|
|
||||||
rewriter.create<mlir::BranchOp>(loc, conditionalBlock, loopOperands);
|
rewriter.create<mlir::cf::BranchOp>(loc, conditionalBlock, loopOperands);
|
||||||
|
|
||||||
// Last loop block
|
// Last loop block
|
||||||
auto *terminator = lastBlock->getTerminator();
|
auto *terminator = lastBlock->getTerminator();
|
||||||
|
@ -105,7 +106,7 @@ public:
|
||||||
: terminator->operand_begin();
|
: terminator->operand_begin();
|
||||||
loopCarried.append(begin, terminator->operand_end());
|
loopCarried.append(begin, terminator->operand_end());
|
||||||
loopCarried.push_back(itersMinusOne);
|
loopCarried.push_back(itersMinusOne);
|
||||||
rewriter.create<mlir::BranchOp>(loc, conditionalBlock, loopCarried);
|
rewriter.create<mlir::cf::BranchOp>(loc, conditionalBlock, loopCarried);
|
||||||
rewriter.eraseOp(terminator);
|
rewriter.eraseOp(terminator);
|
||||||
|
|
||||||
// Conditional block
|
// Conditional block
|
||||||
|
@ -114,9 +115,9 @@ public:
|
||||||
auto comparison = rewriter.create<mlir::arith::CmpIOp>(
|
auto comparison = rewriter.create<mlir::arith::CmpIOp>(
|
||||||
loc, arith::CmpIPredicate::sgt, itersLeft, zero);
|
loc, arith::CmpIPredicate::sgt, itersLeft, zero);
|
||||||
|
|
||||||
rewriter.create<mlir::CondBranchOp>(loc, comparison, firstBlock,
|
rewriter.create<mlir::cf::CondBranchOp>(
|
||||||
llvm::ArrayRef<mlir::Value>(), endBlock,
|
loc, comparison, firstBlock, llvm::ArrayRef<mlir::Value>(), endBlock,
|
||||||
llvm::ArrayRef<mlir::Value>());
|
llvm::ArrayRef<mlir::Value>());
|
||||||
|
|
||||||
// The result of the loop operation is the values of the condition block
|
// The result of the loop operation is the values of the condition block
|
||||||
// arguments except the induction variable on the last iteration.
|
// arguments except the induction variable on the last iteration.
|
||||||
|
@ -155,7 +156,7 @@ public:
|
||||||
} else {
|
} else {
|
||||||
continueBlock =
|
continueBlock =
|
||||||
rewriter.createBlock(remainingOpsBlock, ifOp.getResultTypes());
|
rewriter.createBlock(remainingOpsBlock, ifOp.getResultTypes());
|
||||||
rewriter.create<mlir::BranchOp>(loc, remainingOpsBlock);
|
rewriter.create<mlir::cf::BranchOp>(loc, remainingOpsBlock);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Move blocks from the "then" region to the region containing 'fir.if',
|
// Move blocks from the "then" region to the region containing 'fir.if',
|
||||||
|
@ -165,7 +166,8 @@ public:
|
||||||
auto *ifOpTerminator = ifOpRegion.back().getTerminator();
|
auto *ifOpTerminator = ifOpRegion.back().getTerminator();
|
||||||
auto ifOpTerminatorOperands = ifOpTerminator->getOperands();
|
auto ifOpTerminatorOperands = ifOpTerminator->getOperands();
|
||||||
rewriter.setInsertionPointToEnd(&ifOpRegion.back());
|
rewriter.setInsertionPointToEnd(&ifOpRegion.back());
|
||||||
rewriter.create<mlir::BranchOp>(loc, continueBlock, ifOpTerminatorOperands);
|
rewriter.create<mlir::cf::BranchOp>(loc, continueBlock,
|
||||||
|
ifOpTerminatorOperands);
|
||||||
rewriter.eraseOp(ifOpTerminator);
|
rewriter.eraseOp(ifOpTerminator);
|
||||||
rewriter.inlineRegionBefore(ifOpRegion, continueBlock);
|
rewriter.inlineRegionBefore(ifOpRegion, continueBlock);
|
||||||
|
|
||||||
|
@ -179,14 +181,14 @@ public:
|
||||||
auto *otherwiseTerm = otherwiseRegion.back().getTerminator();
|
auto *otherwiseTerm = otherwiseRegion.back().getTerminator();
|
||||||
auto otherwiseTermOperands = otherwiseTerm->getOperands();
|
auto otherwiseTermOperands = otherwiseTerm->getOperands();
|
||||||
rewriter.setInsertionPointToEnd(&otherwiseRegion.back());
|
rewriter.setInsertionPointToEnd(&otherwiseRegion.back());
|
||||||
rewriter.create<mlir::BranchOp>(loc, continueBlock,
|
rewriter.create<mlir::cf::BranchOp>(loc, continueBlock,
|
||||||
otherwiseTermOperands);
|
otherwiseTermOperands);
|
||||||
rewriter.eraseOp(otherwiseTerm);
|
rewriter.eraseOp(otherwiseTerm);
|
||||||
rewriter.inlineRegionBefore(otherwiseRegion, continueBlock);
|
rewriter.inlineRegionBefore(otherwiseRegion, continueBlock);
|
||||||
}
|
}
|
||||||
|
|
||||||
rewriter.setInsertionPointToEnd(condBlock);
|
rewriter.setInsertionPointToEnd(condBlock);
|
||||||
rewriter.create<mlir::CondBranchOp>(
|
rewriter.create<mlir::cf::CondBranchOp>(
|
||||||
loc, ifOp.condition(), ifOpBlock, llvm::ArrayRef<mlir::Value>(),
|
loc, ifOp.condition(), ifOpBlock, llvm::ArrayRef<mlir::Value>(),
|
||||||
otherwiseBlock, llvm::ArrayRef<mlir::Value>());
|
otherwiseBlock, llvm::ArrayRef<mlir::Value>());
|
||||||
rewriter.replaceOp(ifOp, continueBlock->getArguments());
|
rewriter.replaceOp(ifOp, continueBlock->getArguments());
|
||||||
|
@ -241,7 +243,7 @@ public:
|
||||||
auto begin = whileOp.finalValue() ? std::next(terminator->operand_begin())
|
auto begin = whileOp.finalValue() ? std::next(terminator->operand_begin())
|
||||||
: terminator->operand_begin();
|
: terminator->operand_begin();
|
||||||
loopCarried.append(begin, terminator->operand_end());
|
loopCarried.append(begin, terminator->operand_end());
|
||||||
rewriter.create<mlir::BranchOp>(loc, conditionBlock, loopCarried);
|
rewriter.create<mlir::cf::BranchOp>(loc, conditionBlock, loopCarried);
|
||||||
rewriter.eraseOp(terminator);
|
rewriter.eraseOp(terminator);
|
||||||
|
|
||||||
// Compute loop bounds before branching to the condition.
|
// Compute loop bounds before branching to the condition.
|
||||||
|
@ -256,7 +258,7 @@ public:
|
||||||
destOperands.push_back(lowerBound);
|
destOperands.push_back(lowerBound);
|
||||||
auto iterOperands = whileOp.getIterOperands();
|
auto iterOperands = whileOp.getIterOperands();
|
||||||
destOperands.append(iterOperands.begin(), iterOperands.end());
|
destOperands.append(iterOperands.begin(), iterOperands.end());
|
||||||
rewriter.create<mlir::BranchOp>(loc, conditionBlock, destOperands);
|
rewriter.create<mlir::cf::BranchOp>(loc, conditionBlock, destOperands);
|
||||||
|
|
||||||
// With the body block done, we can fill in the condition block.
|
// With the body block done, we can fill in the condition block.
|
||||||
rewriter.setInsertionPointToEnd(conditionBlock);
|
rewriter.setInsertionPointToEnd(conditionBlock);
|
||||||
|
@ -278,9 +280,9 @@ public:
|
||||||
// Remember to AND in the early-exit bool.
|
// Remember to AND in the early-exit bool.
|
||||||
auto comparison =
|
auto comparison =
|
||||||
rewriter.create<mlir::arith::AndIOp>(loc, iterateVar, cmp2);
|
rewriter.create<mlir::arith::AndIOp>(loc, iterateVar, cmp2);
|
||||||
rewriter.create<mlir::CondBranchOp>(loc, comparison, firstBodyBlock,
|
rewriter.create<mlir::cf::CondBranchOp>(
|
||||||
llvm::ArrayRef<mlir::Value>(), endBlock,
|
loc, comparison, firstBodyBlock, llvm::ArrayRef<mlir::Value>(),
|
||||||
llvm::ArrayRef<mlir::Value>());
|
endBlock, llvm::ArrayRef<mlir::Value>());
|
||||||
// The result of the loop operation is the values of the condition block
|
// The result of the loop operation is the values of the condition block
|
||||||
// arguments except the induction variable on the last iteration.
|
// arguments except the induction variable on the last iteration.
|
||||||
auto args = whileOp.finalValue()
|
auto args = whileOp.finalValue()
|
||||||
|
@ -300,8 +302,8 @@ public:
|
||||||
patterns.insert<CfgLoopConv, CfgIfConv, CfgIterWhileConv>(
|
patterns.insert<CfgLoopConv, CfgIfConv, CfgIterWhileConv>(
|
||||||
context, forceLoopToExecuteOnce);
|
context, forceLoopToExecuteOnce);
|
||||||
mlir::ConversionTarget target(*context);
|
mlir::ConversionTarget target(*context);
|
||||||
target.addLegalDialect<mlir::AffineDialect, FIROpsDialect,
|
target.addLegalDialect<mlir::AffineDialect, mlir::cf::ControlFlowDialect,
|
||||||
mlir::StandardOpsDialect>();
|
FIROpsDialect, mlir::StandardOpsDialect>();
|
||||||
|
|
||||||
// apply the patterns
|
// apply the patterns
|
||||||
target.addIllegalOp<ResultOp, DoLoopOp, IfOp, IterWhileOp>();
|
target.addIllegalOp<ResultOp, DoLoopOp, IfOp, IterWhileOp>();
|
||||||
|
|
|
@ -10,10 +10,10 @@ func @select_case_charachter(%arg0: !fir.char<2, 10>, %arg1: !fir.char<2, 10>, %
|
||||||
unit, ^bb3]
|
unit, ^bb3]
|
||||||
^bb1:
|
^bb1:
|
||||||
%c1_i32 = arith.constant 1 : i32
|
%c1_i32 = arith.constant 1 : i32
|
||||||
br ^bb3
|
cf.br ^bb3
|
||||||
^bb2:
|
^bb2:
|
||||||
%c2_i32 = arith.constant 2 : i32
|
%c2_i32 = arith.constant 2 : i32
|
||||||
br ^bb3
|
cf.br ^bb3
|
||||||
^bb3:
|
^bb3:
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
|
@ -1175,23 +1175,23 @@ func @select_case_integer(%arg0: !fir.ref<i32>) -> i32 {
|
||||||
^bb1: // pred: ^bb0
|
^bb1: // pred: ^bb0
|
||||||
%c1_i32_0 = arith.constant 1 : i32
|
%c1_i32_0 = arith.constant 1 : i32
|
||||||
fir.store %c1_i32_0 to %arg0 : !fir.ref<i32>
|
fir.store %c1_i32_0 to %arg0 : !fir.ref<i32>
|
||||||
br ^bb6
|
cf.br ^bb6
|
||||||
^bb2: // pred: ^bb0
|
^bb2: // pred: ^bb0
|
||||||
%c2_i32_1 = arith.constant 2 : i32
|
%c2_i32_1 = arith.constant 2 : i32
|
||||||
fir.store %c2_i32_1 to %arg0 : !fir.ref<i32>
|
fir.store %c2_i32_1 to %arg0 : !fir.ref<i32>
|
||||||
br ^bb6
|
cf.br ^bb6
|
||||||
^bb3: // pred: ^bb0
|
^bb3: // pred: ^bb0
|
||||||
%c0_i32 = arith.constant 0 : i32
|
%c0_i32 = arith.constant 0 : i32
|
||||||
fir.store %c0_i32 to %arg0 : !fir.ref<i32>
|
fir.store %c0_i32 to %arg0 : !fir.ref<i32>
|
||||||
br ^bb6
|
cf.br ^bb6
|
||||||
^bb4: // pred: ^bb0
|
^bb4: // pred: ^bb0
|
||||||
%c4_i32_2 = arith.constant 4 : i32
|
%c4_i32_2 = arith.constant 4 : i32
|
||||||
fir.store %c4_i32_2 to %arg0 : !fir.ref<i32>
|
fir.store %c4_i32_2 to %arg0 : !fir.ref<i32>
|
||||||
br ^bb6
|
cf.br ^bb6
|
||||||
^bb5: // 3 preds: ^bb0, ^bb0, ^bb0
|
^bb5: // 3 preds: ^bb0, ^bb0, ^bb0
|
||||||
%c7_i32_3 = arith.constant 7 : i32
|
%c7_i32_3 = arith.constant 7 : i32
|
||||||
fir.store %c7_i32_3 to %arg0 : !fir.ref<i32>
|
fir.store %c7_i32_3 to %arg0 : !fir.ref<i32>
|
||||||
br ^bb6
|
cf.br ^bb6
|
||||||
^bb6: // 5 preds: ^bb1, ^bb2, ^bb3, ^bb4, ^bb5
|
^bb6: // 5 preds: ^bb1, ^bb2, ^bb3, ^bb4, ^bb5
|
||||||
%3 = fir.load %arg0 : !fir.ref<i32>
|
%3 = fir.load %arg0 : !fir.ref<i32>
|
||||||
return %3 : i32
|
return %3 : i32
|
||||||
|
@ -1275,10 +1275,10 @@ func @select_case_logical(%arg0: !fir.ref<!fir.logical<4>>) {
|
||||||
unit, ^bb3]
|
unit, ^bb3]
|
||||||
^bb1:
|
^bb1:
|
||||||
%c1_i32 = arith.constant 1 : i32
|
%c1_i32 = arith.constant 1 : i32
|
||||||
br ^bb3
|
cf.br ^bb3
|
||||||
^bb2:
|
^bb2:
|
||||||
%c2_i32 = arith.constant 2 : i32
|
%c2_i32 = arith.constant 2 : i32
|
||||||
br ^bb3
|
cf.br ^bb3
|
||||||
^bb3:
|
^bb3:
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
|
@ -9,10 +9,10 @@ func @load_store_chain_removal(%arg0: !fir.ref<!fir.array<60xi32>>, %arg1: !fir.
|
||||||
%c1 = arith.constant 1 : index
|
%c1 = arith.constant 1 : index
|
||||||
%0 = fir.alloca i32 {bindc_name = "i", uniq_name = "_QFf1dcEi"}
|
%0 = fir.alloca i32 {bindc_name = "i", uniq_name = "_QFf1dcEi"}
|
||||||
%1 = fir.alloca !fir.array<60xi32> {bindc_name = "t1", uniq_name = "_QFf1dcEt1"}
|
%1 = fir.alloca !fir.array<60xi32> {bindc_name = "t1", uniq_name = "_QFf1dcEt1"}
|
||||||
br ^bb1(%c1, %c60 : index, index)
|
cf.br ^bb1(%c1, %c60 : index, index)
|
||||||
^bb1(%2: index, %3: index): // 2 preds: ^bb0, ^bb2
|
^bb1(%2: index, %3: index): // 2 preds: ^bb0, ^bb2
|
||||||
%4 = arith.cmpi sgt, %3, %c0 : index
|
%4 = arith.cmpi sgt, %3, %c0 : index
|
||||||
cond_br %4, ^bb2, ^bb3
|
cf.cond_br %4, ^bb2, ^bb3
|
||||||
^bb2: // pred: ^bb1
|
^bb2: // pred: ^bb1
|
||||||
%5 = fir.convert %2 : (index) -> i32
|
%5 = fir.convert %2 : (index) -> i32
|
||||||
fir.store %5 to %0 : !fir.ref<i32>
|
fir.store %5 to %0 : !fir.ref<i32>
|
||||||
|
@ -26,14 +26,14 @@ func @load_store_chain_removal(%arg0: !fir.ref<!fir.array<60xi32>>, %arg1: !fir.
|
||||||
fir.store %11 to %12 : !fir.ref<i32>
|
fir.store %11 to %12 : !fir.ref<i32>
|
||||||
%13 = arith.addi %2, %c1 : index
|
%13 = arith.addi %2, %c1 : index
|
||||||
%14 = arith.subi %3, %c1 : index
|
%14 = arith.subi %3, %c1 : index
|
||||||
br ^bb1(%13, %14 : index, index)
|
cf.br ^bb1(%13, %14 : index, index)
|
||||||
^bb3: // pred: ^bb1
|
^bb3: // pred: ^bb1
|
||||||
%15 = fir.convert %2 : (index) -> i32
|
%15 = fir.convert %2 : (index) -> i32
|
||||||
fir.store %15 to %0 : !fir.ref<i32>
|
fir.store %15 to %0 : !fir.ref<i32>
|
||||||
br ^bb4(%c1, %c60 : index, index)
|
cf.br ^bb4(%c1, %c60 : index, index)
|
||||||
^bb4(%16: index, %17: index): // 2 preds: ^bb3, ^bb5
|
^bb4(%16: index, %17: index): // 2 preds: ^bb3, ^bb5
|
||||||
%18 = arith.cmpi sgt, %17, %c0 : index
|
%18 = arith.cmpi sgt, %17, %c0 : index
|
||||||
cond_br %18, ^bb5, ^bb6
|
cf.cond_br %18, ^bb5, ^bb6
|
||||||
^bb5: // pred: ^bb4
|
^bb5: // pred: ^bb4
|
||||||
%19 = fir.convert %16 : (index) -> i32
|
%19 = fir.convert %16 : (index) -> i32
|
||||||
fir.store %19 to %0 : !fir.ref<i32>
|
fir.store %19 to %0 : !fir.ref<i32>
|
||||||
|
@ -49,7 +49,7 @@ func @load_store_chain_removal(%arg0: !fir.ref<!fir.array<60xi32>>, %arg1: !fir.
|
||||||
fir.store %27 to %28 : !fir.ref<i32>
|
fir.store %27 to %28 : !fir.ref<i32>
|
||||||
%29 = arith.addi %16, %c1 : index
|
%29 = arith.addi %16, %c1 : index
|
||||||
%30 = arith.subi %17, %c1 : index
|
%30 = arith.subi %17, %c1 : index
|
||||||
br ^bb4(%29, %30 : index, index)
|
cf.br ^bb4(%29, %30 : index, index)
|
||||||
^bb6: // pred: ^bb4
|
^bb6: // pred: ^bb4
|
||||||
%31 = fir.convert %16 : (index) -> i32
|
%31 = fir.convert %16 : (index) -> i32
|
||||||
fir.store %31 to %0 : !fir.ref<i32>
|
fir.store %31 to %0 : !fir.ref<i32>
|
||||||
|
|
|
@ -13,7 +13,7 @@ FIRTransforms
|
||||||
FIRBuilder
|
FIRBuilder
|
||||||
${dialect_libs}
|
${dialect_libs}
|
||||||
MLIRAffineToStandard
|
MLIRAffineToStandard
|
||||||
MLIRSCFToStandard
|
MLIRSCFToControlFlow
|
||||||
FortranCommon
|
FortranCommon
|
||||||
FortranParser
|
FortranParser
|
||||||
FortranEvaluate
|
FortranEvaluate
|
||||||
|
|
|
@ -38,7 +38,6 @@
|
||||||
#include "flang/Semantics/semantics.h"
|
#include "flang/Semantics/semantics.h"
|
||||||
#include "flang/Semantics/unparse-with-symbols.h"
|
#include "flang/Semantics/unparse-with-symbols.h"
|
||||||
#include "flang/Version.inc"
|
#include "flang/Version.inc"
|
||||||
#include "mlir/Conversion/SCFToStandard/SCFToStandard.h"
|
|
||||||
#include "mlir/IR/AsmState.h"
|
#include "mlir/IR/AsmState.h"
|
||||||
#include "mlir/IR/BuiltinOps.h"
|
#include "mlir/IR/BuiltinOps.h"
|
||||||
#include "mlir/IR/MLIRContext.h"
|
#include "mlir/IR/MLIRContext.h"
|
||||||
|
|
|
@ -18,7 +18,7 @@ target_link_libraries(fir-opt PRIVATE
|
||||||
MLIRTransforms
|
MLIRTransforms
|
||||||
MLIRAffineToStandard
|
MLIRAffineToStandard
|
||||||
MLIRAnalysis
|
MLIRAnalysis
|
||||||
MLIRSCFToStandard
|
MLIRSCFToControlFlow
|
||||||
MLIRParser
|
MLIRParser
|
||||||
MLIRStandardToLLVM
|
MLIRStandardToLLVM
|
||||||
MLIRSupport
|
MLIRSupport
|
||||||
|
|
|
@ -17,7 +17,7 @@ target_link_libraries(tco PRIVATE
|
||||||
MLIRTransforms
|
MLIRTransforms
|
||||||
MLIRAffineToStandard
|
MLIRAffineToStandard
|
||||||
MLIRAnalysis
|
MLIRAnalysis
|
||||||
MLIRSCFToStandard
|
MLIRSCFToControlFlow
|
||||||
MLIRParser
|
MLIRParser
|
||||||
MLIRStandardToLLVM
|
MLIRStandardToLLVM
|
||||||
MLIRSupport
|
MLIRSupport
|
||||||
|
|
|
@ -17,7 +17,6 @@
|
||||||
#include "flang/Optimizer/Support/InternalNames.h"
|
#include "flang/Optimizer/Support/InternalNames.h"
|
||||||
#include "flang/Optimizer/Support/KindMapping.h"
|
#include "flang/Optimizer/Support/KindMapping.h"
|
||||||
#include "flang/Optimizer/Transforms/Passes.h"
|
#include "flang/Optimizer/Transforms/Passes.h"
|
||||||
#include "mlir/Conversion/SCFToStandard/SCFToStandard.h"
|
|
||||||
#include "mlir/IR/AsmState.h"
|
#include "mlir/IR/AsmState.h"
|
||||||
#include "mlir/IR/BuiltinOps.h"
|
#include "mlir/IR/BuiltinOps.h"
|
||||||
#include "mlir/IR/MLIRContext.h"
|
#include "mlir/IR/MLIRContext.h"
|
||||||
|
|
|
@ -26,7 +26,7 @@ def setup_passes(mlir_module):
|
||||||
f"sparse-tensor-conversion,"
|
f"sparse-tensor-conversion,"
|
||||||
f"builtin.func"
|
f"builtin.func"
|
||||||
f"(linalg-bufferize,convert-linalg-to-loops,convert-vector-to-scf),"
|
f"(linalg-bufferize,convert-linalg-to-loops,convert-vector-to-scf),"
|
||||||
f"convert-scf-to-std,"
|
f"convert-scf-to-cf,"
|
||||||
f"func-bufferize,"
|
f"func-bufferize,"
|
||||||
f"arith-bufferize,"
|
f"arith-bufferize,"
|
||||||
f"builtin.func(tensor-bufferize,finalizing-bufferize),"
|
f"builtin.func(tensor-bufferize,finalizing-bufferize),"
|
||||||
|
|
|
@ -41,12 +41,12 @@ Example for breaking the invariant:
|
||||||
```mlir
|
```mlir
|
||||||
func @condBranch(%arg0: i1, %arg1: memref<2xf32>) {
|
func @condBranch(%arg0: i1, %arg1: memref<2xf32>) {
|
||||||
%0 = memref.alloc() : memref<2xf32>
|
%0 = memref.alloc() : memref<2xf32>
|
||||||
cond_br %arg0, ^bb1, ^bb2
|
cf.cond_br %arg0, ^bb1, ^bb2
|
||||||
^bb1:
|
^bb1:
|
||||||
br ^bb3()
|
cf.br ^bb3()
|
||||||
^bb2:
|
^bb2:
|
||||||
partial_write(%0, %0)
|
partial_write(%0, %0)
|
||||||
br ^bb3()
|
cf.br ^bb3()
|
||||||
^bb3():
|
^bb3():
|
||||||
test.copy(%0, %arg1) : (memref<2xf32>, memref<2xf32>) -> ()
|
test.copy(%0, %arg1) : (memref<2xf32>, memref<2xf32>) -> ()
|
||||||
return
|
return
|
||||||
|
@ -74,13 +74,13 @@ untracked allocations are mixed:
|
||||||
func @mixedAllocation(%arg0: i1) {
|
func @mixedAllocation(%arg0: i1) {
|
||||||
%0 = memref.alloca() : memref<2xf32> // aliases: %2
|
%0 = memref.alloca() : memref<2xf32> // aliases: %2
|
||||||
%1 = memref.alloc() : memref<2xf32> // aliases: %2
|
%1 = memref.alloc() : memref<2xf32> // aliases: %2
|
||||||
cond_br %arg0, ^bb1, ^bb2
|
cf.cond_br %arg0, ^bb1, ^bb2
|
||||||
^bb1:
|
^bb1:
|
||||||
use(%0)
|
use(%0)
|
||||||
br ^bb3(%0 : memref<2xf32>)
|
cf.br ^bb3(%0 : memref<2xf32>)
|
||||||
^bb2:
|
^bb2:
|
||||||
use(%1)
|
use(%1)
|
||||||
br ^bb3(%1 : memref<2xf32>)
|
cf.br ^bb3(%1 : memref<2xf32>)
|
||||||
^bb3(%2: memref<2xf32>):
|
^bb3(%2: memref<2xf32>):
|
||||||
...
|
...
|
||||||
}
|
}
|
||||||
|
@ -129,13 +129,13 @@ BufferHoisting pass:
|
||||||
|
|
||||||
```mlir
|
```mlir
|
||||||
func @condBranch(%arg0: i1, %arg1: memref<2xf32>, %arg2: memref<2xf32>) {
|
func @condBranch(%arg0: i1, %arg1: memref<2xf32>, %arg2: memref<2xf32>) {
|
||||||
cond_br %arg0, ^bb1, ^bb2
|
cf.cond_br %arg0, ^bb1, ^bb2
|
||||||
^bb1:
|
^bb1:
|
||||||
br ^bb3(%arg1 : memref<2xf32>)
|
cf.br ^bb3(%arg1 : memref<2xf32>)
|
||||||
^bb2:
|
^bb2:
|
||||||
%0 = memref.alloc() : memref<2xf32> // aliases: %1
|
%0 = memref.alloc() : memref<2xf32> // aliases: %1
|
||||||
use(%0)
|
use(%0)
|
||||||
br ^bb3(%0 : memref<2xf32>)
|
cf.br ^bb3(%0 : memref<2xf32>)
|
||||||
^bb3(%1: memref<2xf32>): // %1 could be %0 or %arg1
|
^bb3(%1: memref<2xf32>): // %1 could be %0 or %arg1
|
||||||
test.copy(%1, %arg2) : (memref<2xf32>, memref<2xf32>) -> ()
|
test.copy(%1, %arg2) : (memref<2xf32>, memref<2xf32>) -> ()
|
||||||
return
|
return
|
||||||
|
@ -150,12 +150,12 @@ of code:
|
||||||
```mlir
|
```mlir
|
||||||
func @condBranch(%arg0: i1, %arg1: memref<2xf32>, %arg2: memref<2xf32>) {
|
func @condBranch(%arg0: i1, %arg1: memref<2xf32>, %arg2: memref<2xf32>) {
|
||||||
%0 = memref.alloc() : memref<2xf32> // moved to bb0
|
%0 = memref.alloc() : memref<2xf32> // moved to bb0
|
||||||
cond_br %arg0, ^bb1, ^bb2
|
cf.cond_br %arg0, ^bb1, ^bb2
|
||||||
^bb1:
|
^bb1:
|
||||||
br ^bb3(%arg1 : memref<2xf32>)
|
cf.br ^bb3(%arg1 : memref<2xf32>)
|
||||||
^bb2:
|
^bb2:
|
||||||
use(%0)
|
use(%0)
|
||||||
br ^bb3(%0 : memref<2xf32>)
|
cf.br ^bb3(%0 : memref<2xf32>)
|
||||||
^bb3(%1: memref<2xf32>):
|
^bb3(%1: memref<2xf32>):
|
||||||
test.copy(%1, %arg2) : (memref<2xf32>, memref<2xf32>) -> ()
|
test.copy(%1, %arg2) : (memref<2xf32>, memref<2xf32>) -> ()
|
||||||
return
|
return
|
||||||
|
@ -175,14 +175,14 @@ func @condBranchDynamicType(
|
||||||
%arg1: memref<?xf32>,
|
%arg1: memref<?xf32>,
|
||||||
%arg2: memref<?xf32>,
|
%arg2: memref<?xf32>,
|
||||||
%arg3: index) {
|
%arg3: index) {
|
||||||
cond_br %arg0, ^bb1, ^bb2(%arg3: index)
|
cf.cond_br %arg0, ^bb1, ^bb2(%arg3: index)
|
||||||
^bb1:
|
^bb1:
|
||||||
br ^bb3(%arg1 : memref<?xf32>)
|
cf.br ^bb3(%arg1 : memref<?xf32>)
|
||||||
^bb2(%0: index):
|
^bb2(%0: index):
|
||||||
%1 = memref.alloc(%0) : memref<?xf32> // cannot be moved upwards due to the data
|
%1 = memref.alloc(%0) : memref<?xf32> // cannot be moved upwards due to the data
|
||||||
// dependency to %0
|
// dependency to %0
|
||||||
use(%1)
|
use(%1)
|
||||||
br ^bb3(%1 : memref<?xf32>)
|
cf.br ^bb3(%1 : memref<?xf32>)
|
||||||
^bb3(%2: memref<?xf32>):
|
^bb3(%2: memref<?xf32>):
|
||||||
test.copy(%2, %arg2) : (memref<?xf32>, memref<?xf32>) -> ()
|
test.copy(%2, %arg2) : (memref<?xf32>, memref<?xf32>) -> ()
|
||||||
return
|
return
|
||||||
|
@ -201,14 +201,14 @@ allocations have already been placed:
|
||||||
```mlir
|
```mlir
|
||||||
func @branch(%arg0: i1) {
|
func @branch(%arg0: i1) {
|
||||||
%0 = memref.alloc() : memref<2xf32> // aliases: %2
|
%0 = memref.alloc() : memref<2xf32> // aliases: %2
|
||||||
cond_br %arg0, ^bb1, ^bb2
|
cf.cond_br %arg0, ^bb1, ^bb2
|
||||||
^bb1:
|
^bb1:
|
||||||
%1 = memref.alloc() : memref<2xf32> // resides here for demonstration purposes
|
%1 = memref.alloc() : memref<2xf32> // resides here for demonstration purposes
|
||||||
// aliases: %2
|
// aliases: %2
|
||||||
br ^bb3(%1 : memref<2xf32>)
|
cf.br ^bb3(%1 : memref<2xf32>)
|
||||||
^bb2:
|
^bb2:
|
||||||
use(%0)
|
use(%0)
|
||||||
br ^bb3(%0 : memref<2xf32>)
|
cf.br ^bb3(%0 : memref<2xf32>)
|
||||||
^bb3(%2: memref<2xf32>):
|
^bb3(%2: memref<2xf32>):
|
||||||
…
|
…
|
||||||
return
|
return
|
||||||
|
@ -233,16 +233,16 @@ result:
|
||||||
```mlir
|
```mlir
|
||||||
func @branch(%arg0: i1) {
|
func @branch(%arg0: i1) {
|
||||||
%0 = memref.alloc() : memref<2xf32>
|
%0 = memref.alloc() : memref<2xf32>
|
||||||
cond_br %arg0, ^bb1, ^bb2
|
cf.cond_br %arg0, ^bb1, ^bb2
|
||||||
^bb1:
|
^bb1:
|
||||||
%1 = memref.alloc() : memref<2xf32>
|
%1 = memref.alloc() : memref<2xf32>
|
||||||
%3 = bufferization.clone %1 : (memref<2xf32>) -> (memref<2xf32>)
|
%3 = bufferization.clone %1 : (memref<2xf32>) -> (memref<2xf32>)
|
||||||
memref.dealloc %1 : memref<2xf32> // %1 can be safely freed here
|
memref.dealloc %1 : memref<2xf32> // %1 can be safely freed here
|
||||||
br ^bb3(%3 : memref<2xf32>)
|
cf.br ^bb3(%3 : memref<2xf32>)
|
||||||
^bb2:
|
^bb2:
|
||||||
use(%0)
|
use(%0)
|
||||||
%4 = bufferization.clone %0 : (memref<2xf32>) -> (memref<2xf32>)
|
%4 = bufferization.clone %0 : (memref<2xf32>) -> (memref<2xf32>)
|
||||||
br ^bb3(%4 : memref<2xf32>)
|
cf.br ^bb3(%4 : memref<2xf32>)
|
||||||
^bb3(%2: memref<2xf32>):
|
^bb3(%2: memref<2xf32>):
|
||||||
…
|
…
|
||||||
memref.dealloc %2 : memref<2xf32> // free temp buffer %2
|
memref.dealloc %2 : memref<2xf32> // free temp buffer %2
|
||||||
|
@ -273,23 +273,23 @@ func @condBranchDynamicTypeNested(
|
||||||
%arg1: memref<?xf32>, // aliases: %3, %4
|
%arg1: memref<?xf32>, // aliases: %3, %4
|
||||||
%arg2: memref<?xf32>,
|
%arg2: memref<?xf32>,
|
||||||
%arg3: index) {
|
%arg3: index) {
|
||||||
cond_br %arg0, ^bb1, ^bb2(%arg3: index)
|
cf.cond_br %arg0, ^bb1, ^bb2(%arg3: index)
|
||||||
^bb1:
|
^bb1:
|
||||||
br ^bb6(%arg1 : memref<?xf32>)
|
cf.br ^bb6(%arg1 : memref<?xf32>)
|
||||||
^bb2(%0: index):
|
^bb2(%0: index):
|
||||||
%1 = memref.alloc(%0) : memref<?xf32> // cannot be moved upwards due to the data
|
%1 = memref.alloc(%0) : memref<?xf32> // cannot be moved upwards due to the data
|
||||||
// dependency to %0
|
// dependency to %0
|
||||||
// aliases: %2, %3, %4
|
// aliases: %2, %3, %4
|
||||||
use(%1)
|
use(%1)
|
||||||
cond_br %arg0, ^bb3, ^bb4
|
cf.cond_br %arg0, ^bb3, ^bb4
|
||||||
^bb3:
|
^bb3:
|
||||||
br ^bb5(%1 : memref<?xf32>)
|
cf.br ^bb5(%1 : memref<?xf32>)
|
||||||
^bb4:
|
^bb4:
|
||||||
br ^bb5(%1 : memref<?xf32>)
|
cf.br ^bb5(%1 : memref<?xf32>)
|
||||||
^bb5(%2: memref<?xf32>): // non-crit. alias of %1, since %1 dominates %2
|
^bb5(%2: memref<?xf32>): // non-crit. alias of %1, since %1 dominates %2
|
||||||
br ^bb6(%2 : memref<?xf32>)
|
cf.br ^bb6(%2 : memref<?xf32>)
|
||||||
^bb6(%3: memref<?xf32>): // crit. alias of %arg1 and %2 (in other words %1)
|
^bb6(%3: memref<?xf32>): // crit. alias of %arg1 and %2 (in other words %1)
|
||||||
br ^bb7(%3 : memref<?xf32>)
|
cf.br ^bb7(%3 : memref<?xf32>)
|
||||||
^bb7(%4: memref<?xf32>): // non-crit. alias of %3, since %3 dominates %4
|
^bb7(%4: memref<?xf32>): // non-crit. alias of %3, since %3 dominates %4
|
||||||
test.copy(%4, %arg2) : (memref<?xf32>, memref<?xf32>) -> ()
|
test.copy(%4, %arg2) : (memref<?xf32>, memref<?xf32>) -> ()
|
||||||
return
|
return
|
||||||
|
@ -306,25 +306,25 @@ func @condBranchDynamicTypeNested(
|
||||||
%arg1: memref<?xf32>,
|
%arg1: memref<?xf32>,
|
||||||
%arg2: memref<?xf32>,
|
%arg2: memref<?xf32>,
|
||||||
%arg3: index) {
|
%arg3: index) {
|
||||||
cond_br %arg0, ^bb1, ^bb2(%arg3 : index)
|
cf.cond_br %arg0, ^bb1, ^bb2(%arg3 : index)
|
||||||
^bb1:
|
^bb1:
|
||||||
// temp buffer required due to alias %3
|
// temp buffer required due to alias %3
|
||||||
%5 = bufferization.clone %arg1 : (memref<?xf32>) -> (memref<?xf32>)
|
%5 = bufferization.clone %arg1 : (memref<?xf32>) -> (memref<?xf32>)
|
||||||
br ^bb6(%5 : memref<?xf32>)
|
cf.br ^bb6(%5 : memref<?xf32>)
|
||||||
^bb2(%0: index):
|
^bb2(%0: index):
|
||||||
%1 = memref.alloc(%0) : memref<?xf32>
|
%1 = memref.alloc(%0) : memref<?xf32>
|
||||||
use(%1)
|
use(%1)
|
||||||
cond_br %arg0, ^bb3, ^bb4
|
cf.cond_br %arg0, ^bb3, ^bb4
|
||||||
^bb3:
|
^bb3:
|
||||||
br ^bb5(%1 : memref<?xf32>)
|
cf.br ^bb5(%1 : memref<?xf32>)
|
||||||
^bb4:
|
^bb4:
|
||||||
br ^bb5(%1 : memref<?xf32>)
|
cf.br ^bb5(%1 : memref<?xf32>)
|
||||||
^bb5(%2: memref<?xf32>):
|
^bb5(%2: memref<?xf32>):
|
||||||
%6 = bufferization.clone %1 : (memref<?xf32>) -> (memref<?xf32>)
|
%6 = bufferization.clone %1 : (memref<?xf32>) -> (memref<?xf32>)
|
||||||
memref.dealloc %1 : memref<?xf32>
|
memref.dealloc %1 : memref<?xf32>
|
||||||
br ^bb6(%6 : memref<?xf32>)
|
cf.br ^bb6(%6 : memref<?xf32>)
|
||||||
^bb6(%3: memref<?xf32>):
|
^bb6(%3: memref<?xf32>):
|
||||||
br ^bb7(%3 : memref<?xf32>)
|
cf.br ^bb7(%3 : memref<?xf32>)
|
||||||
^bb7(%4: memref<?xf32>):
|
^bb7(%4: memref<?xf32>):
|
||||||
test.copy(%4, %arg2) : (memref<?xf32>, memref<?xf32>) -> ()
|
test.copy(%4, %arg2) : (memref<?xf32>, memref<?xf32>) -> ()
|
||||||
memref.dealloc %3 : memref<?xf32> // free %3, since %4 is a non-crit. alias of %3
|
memref.dealloc %3 : memref<?xf32> // free %3, since %4 is a non-crit. alias of %3
|
||||||
|
|
|
@ -295,7 +295,7 @@ A few examples are shown below:
|
||||||
```mlir
|
```mlir
|
||||||
// Expect an error on the same line.
|
// Expect an error on the same line.
|
||||||
func @bad_branch() {
|
func @bad_branch() {
|
||||||
br ^missing // expected-error {{reference to an undefined block}}
|
cf.br ^missing // expected-error {{reference to an undefined block}}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Expect an error on an adjacent line.
|
// Expect an error on an adjacent line.
|
||||||
|
|
|
@ -114,8 +114,8 @@ struct MyTarget : public ConversionTarget {
|
||||||
/// All operations within the GPU dialect are illegal.
|
/// All operations within the GPU dialect are illegal.
|
||||||
addIllegalDialect<GPUDialect>();
|
addIllegalDialect<GPUDialect>();
|
||||||
|
|
||||||
/// Mark `std.br` and `std.cond_br` as illegal.
|
/// Mark `cf.br` and `cf.cond_br` as illegal.
|
||||||
addIllegalOp<BranchOp, CondBranchOp>();
|
addIllegalOp<cf::BranchOp, cf::CondBranchOp>();
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Implement the default legalization handler to handle operations marked as
|
/// Implement the default legalization handler to handle operations marked as
|
||||||
|
|
|
@ -23,10 +23,11 @@ argument `-declare-variables-at-top`.
|
||||||
Besides operations part of the EmitC dialect, the Cpp target supports
|
Besides operations part of the EmitC dialect, the Cpp target supports
|
||||||
translating the following operations:
|
translating the following operations:
|
||||||
|
|
||||||
|
* 'cf' Dialect
|
||||||
|
* `cf.br`
|
||||||
|
* `cf.cond_br`
|
||||||
* 'std' Dialect
|
* 'std' Dialect
|
||||||
* `std.br`
|
|
||||||
* `std.call`
|
* `std.call`
|
||||||
* `std.cond_br`
|
|
||||||
* `std.constant`
|
* `std.constant`
|
||||||
* `std.return`
|
* `std.return`
|
||||||
* 'scf' Dialect
|
* 'scf' Dialect
|
||||||
|
|
|
@ -391,21 +391,21 @@ arguments:
|
||||||
```mlir
|
```mlir
|
||||||
func @simple(i64, i1) -> i64 {
|
func @simple(i64, i1) -> i64 {
|
||||||
^bb0(%a: i64, %cond: i1): // Code dominated by ^bb0 may refer to %a
|
^bb0(%a: i64, %cond: i1): // Code dominated by ^bb0 may refer to %a
|
||||||
cond_br %cond, ^bb1, ^bb2
|
cf.cond_br %cond, ^bb1, ^bb2
|
||||||
|
|
||||||
^bb1:
|
^bb1:
|
||||||
br ^bb3(%a: i64) // Branch passes %a as the argument
|
cf.br ^bb3(%a: i64) // Branch passes %a as the argument
|
||||||
|
|
||||||
^bb2:
|
^bb2:
|
||||||
%b = arith.addi %a, %a : i64
|
%b = arith.addi %a, %a : i64
|
||||||
br ^bb3(%b: i64) // Branch passes %b as the argument
|
cf.br ^bb3(%b: i64) // Branch passes %b as the argument
|
||||||
|
|
||||||
// ^bb3 receives an argument, named %c, from predecessors
|
// ^bb3 receives an argument, named %c, from predecessors
|
||||||
// and passes it on to bb4 along with %a. %a is referenced
|
// and passes it on to bb4 along with %a. %a is referenced
|
||||||
// directly from its defining operation and is not passed through
|
// directly from its defining operation and is not passed through
|
||||||
// an argument of ^bb3.
|
// an argument of ^bb3.
|
||||||
^bb3(%c: i64):
|
^bb3(%c: i64):
|
||||||
br ^bb4(%c, %a : i64, i64)
|
cf.br ^bb4(%c, %a : i64, i64)
|
||||||
|
|
||||||
^bb4(%d : i64, %e : i64):
|
^bb4(%d : i64, %e : i64):
|
||||||
%0 = arith.addi %d, %e : i64
|
%0 = arith.addi %d, %e : i64
|
||||||
|
@ -525,12 +525,12 @@ Example:
|
||||||
```mlir
|
```mlir
|
||||||
func @accelerator_compute(i64, i1) -> i64 { // An SSACFG region
|
func @accelerator_compute(i64, i1) -> i64 { // An SSACFG region
|
||||||
^bb0(%a: i64, %cond: i1): // Code dominated by ^bb0 may refer to %a
|
^bb0(%a: i64, %cond: i1): // Code dominated by ^bb0 may refer to %a
|
||||||
cond_br %cond, ^bb1, ^bb2
|
cf.cond_br %cond, ^bb1, ^bb2
|
||||||
|
|
||||||
^bb1:
|
^bb1:
|
||||||
// This def for %value does not dominate ^bb2
|
// This def for %value does not dominate ^bb2
|
||||||
%value = "op.convert"(%a) : (i64) -> i64
|
%value = "op.convert"(%a) : (i64) -> i64
|
||||||
br ^bb3(%a: i64) // Branch passes %a as the argument
|
cf.br ^bb3(%a: i64) // Branch passes %a as the argument
|
||||||
|
|
||||||
^bb2:
|
^bb2:
|
||||||
accelerator.launch() { // An SSACFG region
|
accelerator.launch() { // An SSACFG region
|
||||||
|
|
|
@ -356,24 +356,24 @@ Example output is shown below:
|
||||||
|
|
||||||
```
|
```
|
||||||
//===-------------------------------------------===//
|
//===-------------------------------------------===//
|
||||||
Processing operation : 'std.cond_br'(0x60f000001120) {
|
Processing operation : 'cf.cond_br'(0x60f000001120) {
|
||||||
"std.cond_br"(%arg0)[^bb2, ^bb2] {operand_segment_sizes = dense<[1, 0, 0]> : vector<3xi32>} : (i1) -> ()
|
"cf.cond_br"(%arg0)[^bb2, ^bb2] {operand_segment_sizes = dense<[1, 0, 0]> : vector<3xi32>} : (i1) -> ()
|
||||||
|
|
||||||
* Pattern SimplifyConstCondBranchPred : 'std.cond_br -> ()' {
|
* Pattern SimplifyConstCondBranchPred : 'cf.cond_br -> ()' {
|
||||||
} -> failure : pattern failed to match
|
} -> failure : pattern failed to match
|
||||||
|
|
||||||
* Pattern SimplifyCondBranchIdenticalSuccessors : 'std.cond_br -> ()' {
|
* Pattern SimplifyCondBranchIdenticalSuccessors : 'cf.cond_br -> ()' {
|
||||||
** Insert : 'std.br'(0x60b000003690)
|
** Insert : 'cf.br'(0x60b000003690)
|
||||||
** Replace : 'std.cond_br'(0x60f000001120)
|
** Replace : 'cf.cond_br'(0x60f000001120)
|
||||||
} -> success : pattern applied successfully
|
} -> success : pattern applied successfully
|
||||||
} -> success : pattern matched
|
} -> success : pattern matched
|
||||||
//===-------------------------------------------===//
|
//===-------------------------------------------===//
|
||||||
```
|
```
|
||||||
|
|
||||||
This output is describing the processing of a `std.cond_br` operation. We first
|
This output is describing the processing of a `cf.cond_br` operation. We first
|
||||||
try to apply the `SimplifyConstCondBranchPred`, which fails. From there, another
|
try to apply the `SimplifyConstCondBranchPred`, which fails. From there, another
|
||||||
pattern (`SimplifyCondBranchIdenticalSuccessors`) is applied that matches the
|
pattern (`SimplifyCondBranchIdenticalSuccessors`) is applied that matches the
|
||||||
`std.cond_br` and replaces it with a `std.br`.
|
`cf.cond_br` and replaces it with a `cf.br`.
|
||||||
|
|
||||||
## Debugging
|
## Debugging
|
||||||
|
|
||||||
|
|
|
@ -560,24 +560,24 @@ func @search(%A: memref<?x?xi32>, %S: <?xi32>, %key : i32) {
|
||||||
|
|
||||||
func @search_body(%A: memref<?x?xi32>, %S: memref<?xi32>, %key: i32, %i : i32) {
|
func @search_body(%A: memref<?x?xi32>, %S: memref<?xi32>, %key: i32, %i : i32) {
|
||||||
%nj = memref.dim %A, 1 : memref<?x?xi32>
|
%nj = memref.dim %A, 1 : memref<?x?xi32>
|
||||||
br ^bb1(0)
|
cf.br ^bb1(0)
|
||||||
|
|
||||||
^bb1(%j: i32)
|
^bb1(%j: i32)
|
||||||
%p1 = arith.cmpi "lt", %j, %nj : i32
|
%p1 = arith.cmpi "lt", %j, %nj : i32
|
||||||
cond_br %p1, ^bb2, ^bb5
|
cf.cond_br %p1, ^bb2, ^bb5
|
||||||
|
|
||||||
^bb2:
|
^bb2:
|
||||||
%v = affine.load %A[%i, %j] : memref<?x?xi32>
|
%v = affine.load %A[%i, %j] : memref<?x?xi32>
|
||||||
%p2 = arith.cmpi "eq", %v, %key : i32
|
%p2 = arith.cmpi "eq", %v, %key : i32
|
||||||
cond_br %p2, ^bb3(%j), ^bb4
|
cf.cond_br %p2, ^bb3(%j), ^bb4
|
||||||
|
|
||||||
^bb3(%j: i32)
|
^bb3(%j: i32)
|
||||||
affine.store %j, %S[%i] : memref<?xi32>
|
affine.store %j, %S[%i] : memref<?xi32>
|
||||||
br ^bb5
|
cf.br ^bb5
|
||||||
|
|
||||||
^bb4:
|
^bb4:
|
||||||
%jinc = arith.addi %j, 1 : i32
|
%jinc = arith.addi %j, 1 : i32
|
||||||
br ^bb1(%jinc)
|
cf.br ^bb1(%jinc)
|
||||||
|
|
||||||
^bb5:
|
^bb5:
|
||||||
return
|
return
|
||||||
|
|
|
@ -94,10 +94,11 @@ multiple stages by relying on
|
||||||
```c++
|
```c++
|
||||||
mlir::RewritePatternSet patterns(&getContext());
|
mlir::RewritePatternSet patterns(&getContext());
|
||||||
mlir::populateAffineToStdConversionPatterns(patterns, &getContext());
|
mlir::populateAffineToStdConversionPatterns(patterns, &getContext());
|
||||||
mlir::populateLoopToStdConversionPatterns(patterns, &getContext());
|
mlir::populateSCFToControlFlowConversionPatterns(patterns);
|
||||||
mlir::arith::populateArithmeticToLLVMConversionPatterns(typeConverter,
|
mlir::arith::populateArithmeticToLLVMConversionPatterns(typeConverter,
|
||||||
patterns);
|
patterns);
|
||||||
mlir::populateStdToLLVMConversionPatterns(typeConverter, patterns);
|
mlir::populateStdToLLVMConversionPatterns(typeConverter, patterns);
|
||||||
|
mlir::cf::populateControlFlowToLLVMConversionPatterns(typeConverter, patterns);
|
||||||
|
|
||||||
// The only remaining operation, to lower from the `toy` dialect, is the
|
// The only remaining operation, to lower from the `toy` dialect, is the
|
||||||
// PrintOp.
|
// PrintOp.
|
||||||
|
@ -207,7 +208,7 @@ define void @main() {
|
||||||
%109 = load double, double* %108
|
%109 = load double, double* %108
|
||||||
%110 = call i32 (i8*, ...) @printf(i8* getelementptr inbounds ([4 x i8], [4 x i8]* @frmt_spec, i64 0, i64 0), double %109)
|
%110 = call i32 (i8*, ...) @printf(i8* getelementptr inbounds ([4 x i8], [4 x i8]* @frmt_spec, i64 0, i64 0), double %109)
|
||||||
%111 = add i64 %100, 1
|
%111 = add i64 %100, 1
|
||||||
br label %99
|
br label %99
|
||||||
|
|
||||||
...
|
...
|
||||||
|
|
||||||
|
|
|
@ -361,7 +361,7 @@
|
||||||
</tspan></tspan><tspan
|
</tspan></tspan><tspan
|
||||||
x="73.476562"
|
x="73.476562"
|
||||||
y="88.293896"><tspan
|
y="88.293896"><tspan
|
||||||
style="font-size:5.64444px">br bb3(%0)</tspan></tspan></text>
|
style="font-size:5.64444px">cf.br bb3(%0)</tspan></tspan></text>
|
||||||
<text
|
<text
|
||||||
xml:space="preserve"
|
xml:space="preserve"
|
||||||
id="text1894"
|
id="text1894"
|
||||||
|
|
Before Width: | Height: | Size: 15 KiB After Width: | Height: | Size: 15 KiB |
|
@ -362,7 +362,7 @@
|
||||||
</tspan></tspan><tspan
|
</tspan></tspan><tspan
|
||||||
x="73.476562"
|
x="73.476562"
|
||||||
y="88.293896"><tspan
|
y="88.293896"><tspan
|
||||||
style="font-size:5.64444px">br bb3(%0)</tspan></tspan></text>
|
style="font-size:5.64444px">cf.br bb3(%0)</tspan></tspan></text>
|
||||||
<text
|
<text
|
||||||
xml:space="preserve"
|
xml:space="preserve"
|
||||||
id="text1894"
|
id="text1894"
|
||||||
|
|
Before Width: | Height: | Size: 15 KiB After Width: | Height: | Size: 15 KiB |
|
@ -26,10 +26,11 @@
|
||||||
|
|
||||||
#include "mlir/Conversion/AffineToStandard/AffineToStandard.h"
|
#include "mlir/Conversion/AffineToStandard/AffineToStandard.h"
|
||||||
#include "mlir/Conversion/ArithmeticToLLVM/ArithmeticToLLVM.h"
|
#include "mlir/Conversion/ArithmeticToLLVM/ArithmeticToLLVM.h"
|
||||||
|
#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
|
||||||
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
|
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
|
||||||
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
|
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
|
||||||
#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
|
#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
|
||||||
#include "mlir/Conversion/SCFToStandard/SCFToStandard.h"
|
#include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h"
|
||||||
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
|
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
|
||||||
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
|
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
|
||||||
#include "mlir/Dialect/Affine/IR/AffineOps.h"
|
#include "mlir/Dialect/Affine/IR/AffineOps.h"
|
||||||
|
@ -200,10 +201,11 @@ void ToyToLLVMLoweringPass::runOnOperation() {
|
||||||
// set of legal ones.
|
// set of legal ones.
|
||||||
RewritePatternSet patterns(&getContext());
|
RewritePatternSet patterns(&getContext());
|
||||||
populateAffineToStdConversionPatterns(patterns);
|
populateAffineToStdConversionPatterns(patterns);
|
||||||
populateLoopToStdConversionPatterns(patterns);
|
populateSCFToControlFlowConversionPatterns(patterns);
|
||||||
mlir::arith::populateArithmeticToLLVMConversionPatterns(typeConverter,
|
mlir::arith::populateArithmeticToLLVMConversionPatterns(typeConverter,
|
||||||
patterns);
|
patterns);
|
||||||
populateMemRefToLLVMConversionPatterns(typeConverter, patterns);
|
populateMemRefToLLVMConversionPatterns(typeConverter, patterns);
|
||||||
|
cf::populateControlFlowToLLVMConversionPatterns(typeConverter, patterns);
|
||||||
populateStdToLLVMConversionPatterns(typeConverter, patterns);
|
populateStdToLLVMConversionPatterns(typeConverter, patterns);
|
||||||
|
|
||||||
// The only remaining operation to lower from the `toy` dialect, is the
|
// The only remaining operation to lower from the `toy` dialect, is the
|
||||||
|
|
|
@ -26,10 +26,11 @@
|
||||||
|
|
||||||
#include "mlir/Conversion/AffineToStandard/AffineToStandard.h"
|
#include "mlir/Conversion/AffineToStandard/AffineToStandard.h"
|
||||||
#include "mlir/Conversion/ArithmeticToLLVM/ArithmeticToLLVM.h"
|
#include "mlir/Conversion/ArithmeticToLLVM/ArithmeticToLLVM.h"
|
||||||
|
#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
|
||||||
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
|
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
|
||||||
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
|
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
|
||||||
#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
|
#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
|
||||||
#include "mlir/Conversion/SCFToStandard/SCFToStandard.h"
|
#include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h"
|
||||||
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
|
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
|
||||||
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
|
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
|
||||||
#include "mlir/Dialect/Affine/IR/AffineOps.h"
|
#include "mlir/Dialect/Affine/IR/AffineOps.h"
|
||||||
|
@ -200,10 +201,11 @@ void ToyToLLVMLoweringPass::runOnOperation() {
|
||||||
// set of legal ones.
|
// set of legal ones.
|
||||||
RewritePatternSet patterns(&getContext());
|
RewritePatternSet patterns(&getContext());
|
||||||
populateAffineToStdConversionPatterns(patterns);
|
populateAffineToStdConversionPatterns(patterns);
|
||||||
populateLoopToStdConversionPatterns(patterns);
|
populateSCFToControlFlowConversionPatterns(patterns);
|
||||||
mlir::arith::populateArithmeticToLLVMConversionPatterns(typeConverter,
|
mlir::arith::populateArithmeticToLLVMConversionPatterns(typeConverter,
|
||||||
patterns);
|
patterns);
|
||||||
populateMemRefToLLVMConversionPatterns(typeConverter, patterns);
|
populateMemRefToLLVMConversionPatterns(typeConverter, patterns);
|
||||||
|
cf::populateControlFlowToLLVMConversionPatterns(typeConverter, patterns);
|
||||||
populateStdToLLVMConversionPatterns(typeConverter, patterns);
|
populateStdToLLVMConversionPatterns(typeConverter, patterns);
|
||||||
|
|
||||||
// The only remaining operation to lower from the `toy` dialect, is the
|
// The only remaining operation to lower from the `toy` dialect, is the
|
||||||
|
|
|
@ -0,0 +1,35 @@
|
||||||
|
//===- ControlFlowToLLVM.h - ControlFlow to LLVM -----------*- C++ ------*-===//
|
||||||
|
//
|
||||||
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
// See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
//
|
||||||
|
// Define conversions from the ControlFlow dialect to the LLVM IR dialect.
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#ifndef MLIR_CONVERSION_CONTROLFLOWTOLLVM_CONTROLFLOWTOLLVM_H
|
||||||
|
#define MLIR_CONVERSION_CONTROLFLOWTOLLVM_CONTROLFLOWTOLLVM_H
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
|
|
||||||
|
namespace mlir {
|
||||||
|
class LLVMTypeConverter;
|
||||||
|
class RewritePatternSet;
|
||||||
|
class Pass;
|
||||||
|
|
||||||
|
namespace cf {
|
||||||
|
/// Collect the patterns to convert from the ControlFlow dialect to LLVM. The
|
||||||
|
/// conversion patterns capture the LLVMTypeConverter by reference meaning the
|
||||||
|
/// references have to remain alive during the entire pattern lifetime.
|
||||||
|
void populateControlFlowToLLVMConversionPatterns(LLVMTypeConverter &converter,
|
||||||
|
RewritePatternSet &patterns);
|
||||||
|
|
||||||
|
/// Creates a pass to convert the ControlFlow dialect into the LLVMIR dialect.
|
||||||
|
std::unique_ptr<Pass> createConvertControlFlowToLLVMPass();
|
||||||
|
} // namespace cf
|
||||||
|
} // namespace mlir
|
||||||
|
|
||||||
|
#endif // MLIR_CONVERSION_CONTROLFLOWTOLLVM_CONTROLFLOWTOLLVM_H
|
|
@ -0,0 +1,28 @@
|
||||||
|
//===- ControlFlowToSPIRV.h - CF to SPIR-V Patterns --------*- C++ ------*-===//
|
||||||
|
//
|
||||||
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
// See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
//
|
||||||
|
// Provides patterns to convert ControlFlow dialect to SPIR-V dialect.
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#ifndef MLIR_CONVERSION_CONTROLFLOWTOSPIRV_CONTROLFLOWTOSPIRV_H
|
||||||
|
#define MLIR_CONVERSION_CONTROLFLOWTOSPIRV_CONTROLFLOWTOSPIRV_H
|
||||||
|
|
||||||
|
namespace mlir {
|
||||||
|
class RewritePatternSet;
|
||||||
|
class SPIRVTypeConverter;
|
||||||
|
|
||||||
|
namespace cf {
|
||||||
|
/// Appends to a pattern list additional patterns for translating ControlFlow
|
||||||
|
/// ops to SPIR-V ops.
|
||||||
|
void populateControlFlowToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
|
||||||
|
RewritePatternSet &patterns);
|
||||||
|
} // namespace cf
|
||||||
|
} // namespace mlir
|
||||||
|
|
||||||
|
#endif // MLIR_CONVERSION_CONTROLFLOWTOSPIRV_CONTROLFLOWTOSPIRV_H
|
|
@ -17,6 +17,8 @@
|
||||||
#include "mlir/Conversion/BufferizationToMemRef/BufferizationToMemRef.h"
|
#include "mlir/Conversion/BufferizationToMemRef/BufferizationToMemRef.h"
|
||||||
#include "mlir/Conversion/ComplexToLLVM/ComplexToLLVM.h"
|
#include "mlir/Conversion/ComplexToLLVM/ComplexToLLVM.h"
|
||||||
#include "mlir/Conversion/ComplexToStandard/ComplexToStandard.h"
|
#include "mlir/Conversion/ComplexToStandard/ComplexToStandard.h"
|
||||||
|
#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
|
||||||
|
#include "mlir/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRV.h"
|
||||||
#include "mlir/Conversion/GPUCommon/GPUCommonPass.h"
|
#include "mlir/Conversion/GPUCommon/GPUCommonPass.h"
|
||||||
#include "mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h"
|
#include "mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h"
|
||||||
#include "mlir/Conversion/GPUToROCDL/GPUToROCDLPass.h"
|
#include "mlir/Conversion/GPUToROCDL/GPUToROCDLPass.h"
|
||||||
|
@ -35,10 +37,10 @@
|
||||||
#include "mlir/Conversion/OpenMPToLLVM/ConvertOpenMPToLLVM.h"
|
#include "mlir/Conversion/OpenMPToLLVM/ConvertOpenMPToLLVM.h"
|
||||||
#include "mlir/Conversion/PDLToPDLInterp/PDLToPDLInterp.h"
|
#include "mlir/Conversion/PDLToPDLInterp/PDLToPDLInterp.h"
|
||||||
#include "mlir/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.h"
|
#include "mlir/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.h"
|
||||||
|
#include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h"
|
||||||
#include "mlir/Conversion/SCFToGPU/SCFToGPUPass.h"
|
#include "mlir/Conversion/SCFToGPU/SCFToGPUPass.h"
|
||||||
#include "mlir/Conversion/SCFToOpenMP/SCFToOpenMP.h"
|
#include "mlir/Conversion/SCFToOpenMP/SCFToOpenMP.h"
|
||||||
#include "mlir/Conversion/SCFToSPIRV/SCFToSPIRVPass.h"
|
#include "mlir/Conversion/SCFToSPIRV/SCFToSPIRVPass.h"
|
||||||
#include "mlir/Conversion/SCFToStandard/SCFToStandard.h"
|
|
||||||
#include "mlir/Conversion/SPIRVToLLVM/SPIRVToLLVMPass.h"
|
#include "mlir/Conversion/SPIRVToLLVM/SPIRVToLLVMPass.h"
|
||||||
#include "mlir/Conversion/ShapeToStandard/ShapeToStandard.h"
|
#include "mlir/Conversion/ShapeToStandard/ShapeToStandard.h"
|
||||||
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
|
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
|
||||||
|
|
|
@ -181,6 +181,28 @@ def ConvertComplexToStandard : Pass<"convert-complex-to-standard", "FuncOp"> {
|
||||||
let dependentDialects = ["math::MathDialect"];
|
let dependentDialects = ["math::MathDialect"];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// ControlFlowToLLVM
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
def ConvertControlFlowToLLVM : Pass<"convert-cf-to-llvm", "ModuleOp"> {
|
||||||
|
let summary = "Convert ControlFlow operations to the LLVM dialect";
|
||||||
|
let description = [{
|
||||||
|
Convert ControlFlow operations into LLVM IR dialect operations.
|
||||||
|
|
||||||
|
If other operations are present and their results are required by the LLVM
|
||||||
|
IR dialect operations, the pass will fail. Any LLVM IR operations or types
|
||||||
|
already present in the IR will be kept as is.
|
||||||
|
}];
|
||||||
|
let constructor = "mlir::cf::createConvertControlFlowToLLVMPass()";
|
||||||
|
let dependentDialects = ["LLVM::LLVMDialect"];
|
||||||
|
let options = [
|
||||||
|
Option<"indexBitwidth", "index-bitwidth", "unsigned",
|
||||||
|
/*default=kDeriveIndexBitwidthFromDataLayout*/"0",
|
||||||
|
"Bitwidth of the index type, 0 to use size of machine word">,
|
||||||
|
];
|
||||||
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// GPUCommon
|
// GPUCommon
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
@ -460,6 +482,17 @@ def ReconcileUnrealizedCasts : Pass<"reconcile-unrealized-casts"> {
|
||||||
let constructor = "mlir::createReconcileUnrealizedCastsPass()";
|
let constructor = "mlir::createReconcileUnrealizedCastsPass()";
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// SCFToControlFlow
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
def SCFToControlFlow : Pass<"convert-scf-to-cf"> {
|
||||||
|
let summary = "Convert SCF dialect to ControlFlow dialect, replacing structured"
|
||||||
|
" control flow with a CFG";
|
||||||
|
let constructor = "mlir::createConvertSCFToCFPass()";
|
||||||
|
let dependentDialects = ["cf::ControlFlowDialect"];
|
||||||
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// SCFToOpenMP
|
// SCFToOpenMP
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
@ -488,17 +521,6 @@ def SCFToSPIRV : Pass<"convert-scf-to-spirv", "ModuleOp"> {
|
||||||
let dependentDialects = ["spirv::SPIRVDialect"];
|
let dependentDialects = ["spirv::SPIRVDialect"];
|
||||||
}
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
|
||||||
// SCFToStandard
|
|
||||||
//===----------------------------------------------------------------------===//
|
|
||||||
|
|
||||||
def SCFToStandard : Pass<"convert-scf-to-std"> {
|
|
||||||
let summary = "Convert SCF dialect to Standard dialect, replacing structured"
|
|
||||||
" control flow with a CFG";
|
|
||||||
let constructor = "mlir::createLowerToCFGPass()";
|
|
||||||
let dependentDialects = ["StandardOpsDialect"];
|
|
||||||
}
|
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// SCFToGPU
|
// SCFToGPU
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
@ -547,7 +569,7 @@ def ConvertShapeConstraints: Pass<"convert-shape-constraints", "FuncOp"> {
|
||||||
computation lowering.
|
computation lowering.
|
||||||
}];
|
}];
|
||||||
let constructor = "mlir::createConvertShapeConstraintsPass()";
|
let constructor = "mlir::createConvertShapeConstraintsPass()";
|
||||||
let dependentDialects = ["StandardOpsDialect", "scf::SCFDialect"];
|
let dependentDialects = ["cf::ControlFlowDialect", "scf::SCFDialect"];
|
||||||
}
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
|
@ -0,0 +1,28 @@
|
||||||
|
//===- SCFToControlFlow.h - Pass entrypoint ---------------------*- C++ -*-===//
|
||||||
|
//
|
||||||
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
// See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#ifndef MLIR_CONVERSION_SCFTOCONTROLFLOW_SCFTOCONTROLFLOW_H_
|
||||||
|
#define MLIR_CONVERSION_SCFTOCONTROLFLOW_SCFTOCONTROLFLOW_H_
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
|
|
||||||
|
namespace mlir {
|
||||||
|
class Pass;
|
||||||
|
class RewritePatternSet;
|
||||||
|
|
||||||
|
/// Collect a set of patterns to convert SCF operations to CFG branch-based
|
||||||
|
/// operations within the ControlFlow dialect.
|
||||||
|
void populateSCFToControlFlowConversionPatterns(RewritePatternSet &patterns);
|
||||||
|
|
||||||
|
/// Creates a pass to convert SCF operations to CFG branch-based operations in
|
||||||
|
/// the ControlFlow dialect.
|
||||||
|
std::unique_ptr<Pass> createConvertSCFToCFPass();
|
||||||
|
|
||||||
|
} // namespace mlir
|
||||||
|
|
||||||
|
#endif // MLIR_CONVERSION_SCFTOCONTROLFLOW_SCFTOCONTROLFLOW_H_
|
|
@ -1,31 +0,0 @@
|
||||||
//===- ConvertSCFToStandard.h - Pass entrypoint -----------------*- C++ -*-===//
|
|
||||||
//
|
|
||||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
||||||
// See https://llvm.org/LICENSE.txt for license information.
|
|
||||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
||||||
//
|
|
||||||
//===----------------------------------------------------------------------===//
|
|
||||||
|
|
||||||
#ifndef MLIR_CONVERSION_SCFTOSTANDARD_SCFTOSTANDARD_H_
|
|
||||||
#define MLIR_CONVERSION_SCFTOSTANDARD_SCFTOSTANDARD_H_
|
|
||||||
|
|
||||||
#include <memory>
|
|
||||||
#include <vector>
|
|
||||||
|
|
||||||
namespace mlir {
|
|
||||||
struct LogicalResult;
|
|
||||||
class Pass;
|
|
||||||
|
|
||||||
class RewritePatternSet;
|
|
||||||
|
|
||||||
/// Collect a set of patterns to lower from scf.for, scf.if, and
|
|
||||||
/// loop.terminator to CFG operations within the Standard dialect, in particular
|
|
||||||
/// convert structured control flow into CFG branch-based control flow.
|
|
||||||
void populateLoopToStdConversionPatterns(RewritePatternSet &patterns);
|
|
||||||
|
|
||||||
/// Creates a pass to convert scf.for, scf.if and loop.terminator ops to CFG.
|
|
||||||
std::unique_ptr<Pass> createLowerToCFGPass();
|
|
||||||
|
|
||||||
} // namespace mlir
|
|
||||||
|
|
||||||
#endif // MLIR_CONVERSION_SCFTOSTANDARD_SCFTOSTANDARD_H_
|
|
|
@ -26,9 +26,9 @@ def BufferDeallocation : Pass<"buffer-deallocation", "FuncOp"> {
|
||||||
#map0 = affine_map<(d0) -> (d0)>
|
#map0 = affine_map<(d0) -> (d0)>
|
||||||
module {
|
module {
|
||||||
func @condBranch(%arg0: i1, %arg1: memref<2xf32>, %arg2: memref<2xf32>) {
|
func @condBranch(%arg0: i1, %arg1: memref<2xf32>, %arg2: memref<2xf32>) {
|
||||||
cond_br %arg0, ^bb1, ^bb2
|
cf.cond_br %arg0, ^bb1, ^bb2
|
||||||
^bb1:
|
^bb1:
|
||||||
br ^bb3(%arg1 : memref<2xf32>)
|
cf.br ^bb3(%arg1 : memref<2xf32>)
|
||||||
^bb2:
|
^bb2:
|
||||||
%0 = memref.alloc() : memref<2xf32>
|
%0 = memref.alloc() : memref<2xf32>
|
||||||
linalg.generic {
|
linalg.generic {
|
||||||
|
@ -40,7 +40,7 @@ def BufferDeallocation : Pass<"buffer-deallocation", "FuncOp"> {
|
||||||
%tmp1 = exp %gen1_arg0 : f32
|
%tmp1 = exp %gen1_arg0 : f32
|
||||||
linalg.yield %tmp1 : f32
|
linalg.yield %tmp1 : f32
|
||||||
}: memref<2xf32>, memref<2xf32>
|
}: memref<2xf32>, memref<2xf32>
|
||||||
br ^bb3(%0 : memref<2xf32>)
|
cf.br ^bb3(%0 : memref<2xf32>)
|
||||||
^bb3(%1: memref<2xf32>):
|
^bb3(%1: memref<2xf32>):
|
||||||
"memref.copy"(%1, %arg2) : (memref<2xf32>, memref<2xf32>) -> ()
|
"memref.copy"(%1, %arg2) : (memref<2xf32>, memref<2xf32>) -> ()
|
||||||
return
|
return
|
||||||
|
@ -55,11 +55,11 @@ def BufferDeallocation : Pass<"buffer-deallocation", "FuncOp"> {
|
||||||
#map0 = affine_map<(d0) -> (d0)>
|
#map0 = affine_map<(d0) -> (d0)>
|
||||||
module {
|
module {
|
||||||
func @condBranch(%arg0: i1, %arg1: memref<2xf32>, %arg2: memref<2xf32>) {
|
func @condBranch(%arg0: i1, %arg1: memref<2xf32>, %arg2: memref<2xf32>) {
|
||||||
cond_br %arg0, ^bb1, ^bb2
|
cf.cond_br %arg0, ^bb1, ^bb2
|
||||||
^bb1: // pred: ^bb0
|
^bb1: // pred: ^bb0
|
||||||
%0 = memref.alloc() : memref<2xf32>
|
%0 = memref.alloc() : memref<2xf32>
|
||||||
memref.copy(%arg1, %0) : memref<2xf32>, memref<2xf32>
|
memref.copy(%arg1, %0) : memref<2xf32>, memref<2xf32>
|
||||||
br ^bb3(%0 : memref<2xf32>)
|
cf.br ^bb3(%0 : memref<2xf32>)
|
||||||
^bb2: // pred: ^bb0
|
^bb2: // pred: ^bb0
|
||||||
%1 = memref.alloc() : memref<2xf32>
|
%1 = memref.alloc() : memref<2xf32>
|
||||||
linalg.generic {
|
linalg.generic {
|
||||||
|
@ -74,7 +74,7 @@ def BufferDeallocation : Pass<"buffer-deallocation", "FuncOp"> {
|
||||||
%2 = memref.alloc() : memref<2xf32>
|
%2 = memref.alloc() : memref<2xf32>
|
||||||
memref.copy(%1, %2) : memref<2xf32>, memref<2xf32>
|
memref.copy(%1, %2) : memref<2xf32>, memref<2xf32>
|
||||||
dealloc %1 : memref<2xf32>
|
dealloc %1 : memref<2xf32>
|
||||||
br ^bb3(%2 : memref<2xf32>)
|
cf.br ^bb3(%2 : memref<2xf32>)
|
||||||
^bb3(%3: memref<2xf32>): // 2 preds: ^bb1, ^bb2
|
^bb3(%3: memref<2xf32>): // 2 preds: ^bb1, ^bb2
|
||||||
memref.copy(%3, %arg2) : memref<2xf32>, memref<2xf32>
|
memref.copy(%3, %arg2) : memref<2xf32>, memref<2xf32>
|
||||||
dealloc %3 : memref<2xf32>
|
dealloc %3 : memref<2xf32>
|
||||||
|
|
|
@ -6,6 +6,7 @@ add_subdirectory(ArmSVE)
|
||||||
add_subdirectory(AMX)
|
add_subdirectory(AMX)
|
||||||
add_subdirectory(Bufferization)
|
add_subdirectory(Bufferization)
|
||||||
add_subdirectory(Complex)
|
add_subdirectory(Complex)
|
||||||
|
add_subdirectory(ControlFlow)
|
||||||
add_subdirectory(DLTI)
|
add_subdirectory(DLTI)
|
||||||
add_subdirectory(EmitC)
|
add_subdirectory(EmitC)
|
||||||
add_subdirectory(GPU)
|
add_subdirectory(GPU)
|
||||||
|
|
|
@ -0,0 +1 @@
|
||||||
|
add_subdirectory(IR)
|
|
@ -0,0 +1,2 @@
|
||||||
|
add_mlir_dialect(ControlFlowOps cf ControlFlowOps)
|
||||||
|
add_mlir_doc(ControlFlowOps ControlFlowDialect Dialects/ -gen-dialect-doc)
|
|
@ -0,0 +1,21 @@
|
||||||
|
//===- ControlFlow.h - ControlFlow Dialect ----------------*- C++ -*-===//
|
||||||
|
//
|
||||||
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
// See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
//
|
||||||
|
// This file defines the ControlFlow dialect.
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#ifndef MLIR_DIALECT_CONTROLFLOW_IR_CONTROLFLOW_H
|
||||||
|
#define MLIR_DIALECT_CONTROLFLOW_IR_CONTROLFLOW_H
|
||||||
|
|
||||||
|
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
|
||||||
|
#include "mlir/IR/Dialect.h"
|
||||||
|
|
||||||
|
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOpsDialect.h.inc"
|
||||||
|
|
||||||
|
#endif // MLIR_DIALECT_CONTROLFLOW_IR_CONTROLFLOW_H
|
|
@ -0,0 +1,30 @@
|
||||||
|
//===- ControlFlowOps.h - ControlFlow Operations ----------------*- C++ -*-===//
|
||||||
|
//
|
||||||
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
// See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
//
|
||||||
|
// This file defines the operations of the ControlFlow dialect.
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#ifndef MLIR_DIALECT_CONTROLFLOW_IR_CONTROLFLOWOPS_H
|
||||||
|
#define MLIR_DIALECT_CONTROLFLOW_IR_CONTROLFLOWOPS_H
|
||||||
|
|
||||||
|
#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
|
||||||
|
#include "mlir/IR/Builders.h"
|
||||||
|
#include "mlir/IR/BuiltinTypes.h"
|
||||||
|
#include "mlir/IR/OpImplementation.h"
|
||||||
|
#include "mlir/Interfaces/ControlFlowInterfaces.h"
|
||||||
|
#include "mlir/Interfaces/SideEffectInterfaces.h"
|
||||||
|
|
||||||
|
namespace mlir {
|
||||||
|
class PatternRewriter;
|
||||||
|
} // namespace mlir
|
||||||
|
|
||||||
|
#define GET_OP_CLASSES
|
||||||
|
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h.inc"
|
||||||
|
|
||||||
|
#endif // MLIR_DIALECT_CONTROLFLOW_IR_CONTROLFLOWOPS_H
|
|
@ -0,0 +1,313 @@
|
||||||
|
//===- ControlFlowOps.td - ControlFlow operations ----------*- tablegen -*-===//
|
||||||
|
//
|
||||||
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
// See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
//
|
||||||
|
// This file contains definitions for the operations within the ControlFlow
|
||||||
|
// dialect.
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#ifndef CONTROLFLOW_OPS
|
||||||
|
#define CONTROLFLOW_OPS
|
||||||
|
|
||||||
|
include "mlir/IR/OpAsmInterface.td"
|
||||||
|
include "mlir/Interfaces/ControlFlowInterfaces.td"
|
||||||
|
include "mlir/Interfaces/SideEffectInterfaces.td"
|
||||||
|
|
||||||
|
def ControlFlow_Dialect : Dialect {
|
||||||
|
let name = "cf";
|
||||||
|
let cppNamespace = "::mlir::cf";
|
||||||
|
let dependentDialects = ["arith::ArithmeticDialect"];
|
||||||
|
let emitAccessorPrefix = kEmitAccessorPrefix_Prefixed;
|
||||||
|
let description = [{
|
||||||
|
This dialect contains low-level, i.e. non-region based, control flow
|
||||||
|
constructs. These constructs generally represent control flow directly
|
||||||
|
on SSA blocks of a control flow graph.
|
||||||
|
}];
|
||||||
|
}
|
||||||
|
|
||||||
|
class CF_Op<string mnemonic, list<Trait> traits = []> :
|
||||||
|
Op<ControlFlow_Dialect, mnemonic, traits>;
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// AssertOp
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
def AssertOp : CF_Op<"assert"> {
|
||||||
|
let summary = "Assert operation with message attribute";
|
||||||
|
let description = [{
|
||||||
|
Assert operation with single boolean operand and an error message attribute.
|
||||||
|
If the argument is `true` this operation has no effect. Otherwise, the
|
||||||
|
program execution will abort. The provided error message may be used by a
|
||||||
|
runtime to propagate the error to the user.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
|
||||||
|
```mlir
|
||||||
|
cf.assert %b, "Expected ... to be true"
|
||||||
|
```
|
||||||
|
}];
|
||||||
|
|
||||||
|
let arguments = (ins I1:$arg, StrAttr:$msg);
|
||||||
|
|
||||||
|
let assemblyFormat = "$arg `,` $msg attr-dict";
|
||||||
|
let hasCanonicalizeMethod = 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// BranchOp
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
def BranchOp : CF_Op<"br", [
|
||||||
|
DeclareOpInterfaceMethods<BranchOpInterface, ["getSuccessorForOperands"]>,
|
||||||
|
NoSideEffect, Terminator
|
||||||
|
]> {
|
||||||
|
let summary = "branch operation";
|
||||||
|
let description = [{
|
||||||
|
The `cf.br` operation represents a direct branch operation to a given
|
||||||
|
block. The operands of this operation are forwarded to the successor block,
|
||||||
|
and the number and type of the operands must match the arguments of the
|
||||||
|
target block.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
|
||||||
|
```mlir
|
||||||
|
^bb2:
|
||||||
|
%2 = call @someFn()
|
||||||
|
cf.br ^bb3(%2 : tensor<*xf32>)
|
||||||
|
^bb3(%3: tensor<*xf32>):
|
||||||
|
```
|
||||||
|
}];
|
||||||
|
|
||||||
|
let arguments = (ins Variadic<AnyType>:$destOperands);
|
||||||
|
let successors = (successor AnySuccessor:$dest);
|
||||||
|
|
||||||
|
let builders = [
|
||||||
|
OpBuilder<(ins "Block *":$dest,
|
||||||
|
CArg<"ValueRange", "{}">:$destOperands), [{
|
||||||
|
$_state.addSuccessors(dest);
|
||||||
|
$_state.addOperands(destOperands);
|
||||||
|
}]>];
|
||||||
|
|
||||||
|
let extraClassDeclaration = [{
|
||||||
|
void setDest(Block *block);
|
||||||
|
|
||||||
|
/// Erase the operand at 'index' from the operand list.
|
||||||
|
void eraseOperand(unsigned index);
|
||||||
|
}];
|
||||||
|
|
||||||
|
let hasCanonicalizeMethod = 1;
|
||||||
|
let assemblyFormat = [{
|
||||||
|
$dest (`(` $destOperands^ `:` type($destOperands) `)`)? attr-dict
|
||||||
|
}];
|
||||||
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// CondBranchOp
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
def CondBranchOp : CF_Op<"cond_br",
|
||||||
|
[AttrSizedOperandSegments,
|
||||||
|
DeclareOpInterfaceMethods<BranchOpInterface, ["getSuccessorForOperands"]>,
|
||||||
|
NoSideEffect, Terminator]> {
|
||||||
|
let summary = "conditional branch operation";
|
||||||
|
let description = [{
|
||||||
|
The `cf.cond_br` terminator operation represents a conditional branch on a
|
||||||
|
boolean (1-bit integer) value. If the bit is set, then the first destination
|
||||||
|
is jumped to; if it is false, the second destination is chosen. The count
|
||||||
|
and types of operands must align with the arguments in the corresponding
|
||||||
|
target blocks.
|
||||||
|
|
||||||
|
The MLIR conditional branch operation is not allowed to target the entry
|
||||||
|
block for a region. The two destinations of the conditional branch operation
|
||||||
|
are allowed to be the same.
|
||||||
|
|
||||||
|
The following example illustrates a function with a conditional branch
|
||||||
|
operation that targets the same block.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
|
||||||
|
```mlir
|
||||||
|
func @select(%a: i32, %b: i32, %flag: i1) -> i32 {
|
||||||
|
// Both targets are the same, operands differ
|
||||||
|
cf.cond_br %flag, ^bb1(%a : i32), ^bb1(%b : i32)
|
||||||
|
|
||||||
|
^bb1(%x : i32) :
|
||||||
|
return %x : i32
|
||||||
|
}
|
||||||
|
```
|
||||||
|
}];
|
||||||
|
|
||||||
|
let arguments = (ins I1:$condition,
|
||||||
|
Variadic<AnyType>:$trueDestOperands,
|
||||||
|
Variadic<AnyType>:$falseDestOperands);
|
||||||
|
let successors = (successor AnySuccessor:$trueDest, AnySuccessor:$falseDest);
|
||||||
|
|
||||||
|
let builders = [
|
||||||
|
OpBuilder<(ins "Value":$condition, "Block *":$trueDest,
|
||||||
|
"ValueRange":$trueOperands, "Block *":$falseDest,
|
||||||
|
"ValueRange":$falseOperands), [{
|
||||||
|
build($_builder, $_state, condition, trueOperands, falseOperands, trueDest,
|
||||||
|
falseDest);
|
||||||
|
}]>,
|
||||||
|
OpBuilder<(ins "Value":$condition, "Block *":$trueDest,
|
||||||
|
"Block *":$falseDest, CArg<"ValueRange", "{}">:$falseOperands), [{
|
||||||
|
build($_builder, $_state, condition, trueDest, ValueRange(), falseDest,
|
||||||
|
falseOperands);
|
||||||
|
}]>];
|
||||||
|
|
||||||
|
let extraClassDeclaration = [{
|
||||||
|
// These are the indices into the dests list.
|
||||||
|
enum { trueIndex = 0, falseIndex = 1 };
|
||||||
|
|
||||||
|
// Accessors for operands to the 'true' destination.
|
||||||
|
Value getTrueOperand(unsigned idx) {
|
||||||
|
assert(idx < getNumTrueOperands());
|
||||||
|
return getOperand(getTrueDestOperandIndex() + idx);
|
||||||
|
}
|
||||||
|
|
||||||
|
void setTrueOperand(unsigned idx, Value value) {
|
||||||
|
assert(idx < getNumTrueOperands());
|
||||||
|
setOperand(getTrueDestOperandIndex() + idx, value);
|
||||||
|
}
|
||||||
|
|
||||||
|
unsigned getNumTrueOperands() { return getTrueOperands().size(); }
|
||||||
|
|
||||||
|
/// Erase the operand at 'index' from the true operand list.
|
||||||
|
void eraseTrueOperand(unsigned index) {
|
||||||
|
getTrueDestOperandsMutable().erase(index);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Accessors for operands to the 'false' destination.
|
||||||
|
Value getFalseOperand(unsigned idx) {
|
||||||
|
assert(idx < getNumFalseOperands());
|
||||||
|
return getOperand(getFalseDestOperandIndex() + idx);
|
||||||
|
}
|
||||||
|
void setFalseOperand(unsigned idx, Value value) {
|
||||||
|
assert(idx < getNumFalseOperands());
|
||||||
|
setOperand(getFalseDestOperandIndex() + idx, value);
|
||||||
|
}
|
||||||
|
|
||||||
|
operand_range getTrueOperands() { return getTrueDestOperands(); }
|
||||||
|
operand_range getFalseOperands() { return getFalseDestOperands(); }
|
||||||
|
|
||||||
|
unsigned getNumFalseOperands() { return getFalseOperands().size(); }
|
||||||
|
|
||||||
|
/// Erase the operand at 'index' from the false operand list.
|
||||||
|
void eraseFalseOperand(unsigned index) {
|
||||||
|
getFalseDestOperandsMutable().erase(index);
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
/// Get the index of the first true destination operand.
|
||||||
|
unsigned getTrueDestOperandIndex() { return 1; }
|
||||||
|
|
||||||
|
/// Get the index of the first false destination operand.
|
||||||
|
unsigned getFalseDestOperandIndex() {
|
||||||
|
return getTrueDestOperandIndex() + getNumTrueOperands();
|
||||||
|
}
|
||||||
|
}];
|
||||||
|
|
||||||
|
let hasCanonicalizer = 1;
|
||||||
|
let assemblyFormat = [{
|
||||||
|
$condition `,`
|
||||||
|
$trueDest (`(` $trueDestOperands^ `:` type($trueDestOperands) `)`)? `,`
|
||||||
|
$falseDest (`(` $falseDestOperands^ `:` type($falseDestOperands) `)`)?
|
||||||
|
attr-dict
|
||||||
|
}];
|
||||||
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// SwitchOp
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
def SwitchOp : CF_Op<"switch",
|
||||||
|
[AttrSizedOperandSegments,
|
||||||
|
DeclareOpInterfaceMethods<BranchOpInterface, ["getSuccessorForOperands"]>,
|
||||||
|
NoSideEffect, Terminator]> {
|
||||||
|
let summary = "switch operation";
|
||||||
|
let description = [{
|
||||||
|
The `cf.switch` terminator operation represents a switch on a signless integer
|
||||||
|
value. If the flag matches one of the specified cases, then the
|
||||||
|
corresponding destination is jumped to. If the flag does not match any of
|
||||||
|
the cases, the default destination is jumped to. The count and types of
|
||||||
|
operands must align with the arguments in the corresponding target blocks.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
|
||||||
|
```mlir
|
||||||
|
cf.switch %flag : i32, [
|
||||||
|
default: ^bb1(%a : i32),
|
||||||
|
42: ^bb1(%b : i32),
|
||||||
|
43: ^bb3(%c : i32)
|
||||||
|
]
|
||||||
|
```
|
||||||
|
}];
|
||||||
|
|
||||||
|
let arguments = (ins
|
||||||
|
AnyInteger:$flag,
|
||||||
|
Variadic<AnyType>:$defaultOperands,
|
||||||
|
VariadicOfVariadic<AnyType, "case_operand_segments">:$caseOperands,
|
||||||
|
OptionalAttr<AnyIntElementsAttr>:$case_values,
|
||||||
|
I32ElementsAttr:$case_operand_segments
|
||||||
|
);
|
||||||
|
let successors = (successor
|
||||||
|
AnySuccessor:$defaultDestination,
|
||||||
|
VariadicSuccessor<AnySuccessor>:$caseDestinations
|
||||||
|
);
|
||||||
|
let builders = [
|
||||||
|
OpBuilder<(ins "Value":$flag,
|
||||||
|
"Block *":$defaultDestination,
|
||||||
|
"ValueRange":$defaultOperands,
|
||||||
|
CArg<"ArrayRef<APInt>", "{}">:$caseValues,
|
||||||
|
CArg<"BlockRange", "{}">:$caseDestinations,
|
||||||
|
CArg<"ArrayRef<ValueRange>", "{}">:$caseOperands)>,
|
||||||
|
OpBuilder<(ins "Value":$flag,
|
||||||
|
"Block *":$defaultDestination,
|
||||||
|
"ValueRange":$defaultOperands,
|
||||||
|
CArg<"ArrayRef<int32_t>", "{}">:$caseValues,
|
||||||
|
CArg<"BlockRange", "{}">:$caseDestinations,
|
||||||
|
CArg<"ArrayRef<ValueRange>", "{}">:$caseOperands)>,
|
||||||
|
OpBuilder<(ins "Value":$flag,
|
||||||
|
"Block *":$defaultDestination,
|
||||||
|
"ValueRange":$defaultOperands,
|
||||||
|
CArg<"DenseIntElementsAttr", "{}">:$caseValues,
|
||||||
|
CArg<"BlockRange", "{}">:$caseDestinations,
|
||||||
|
CArg<"ArrayRef<ValueRange>", "{}">:$caseOperands)>
|
||||||
|
];
|
||||||
|
|
||||||
|
let assemblyFormat = [{
|
||||||
|
$flag `:` type($flag) `,` `[` `\n`
|
||||||
|
custom<SwitchOpCases>(ref(type($flag)),$defaultDestination,
|
||||||
|
$defaultOperands,
|
||||||
|
type($defaultOperands),
|
||||||
|
$case_values,
|
||||||
|
$caseDestinations,
|
||||||
|
$caseOperands,
|
||||||
|
type($caseOperands))
|
||||||
|
`]`
|
||||||
|
attr-dict
|
||||||
|
}];
|
||||||
|
|
||||||
|
let extraClassDeclaration = [{
|
||||||
|
/// Return the operands for the case destination block at the given index.
|
||||||
|
OperandRange getCaseOperands(unsigned index) {
|
||||||
|
return getCaseOperands()[index];
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Return a mutable range of operands for the case destination block at the
|
||||||
|
/// given index.
|
||||||
|
MutableOperandRange getCaseOperandsMutable(unsigned index) {
|
||||||
|
return getCaseOperandsMutable()[index];
|
||||||
|
}
|
||||||
|
}];
|
||||||
|
|
||||||
|
let hasCanonicalizer = 1;
|
||||||
|
let hasVerifier = 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
#endif // STANDARD_OPS
|
|
@ -84,15 +84,15 @@ def ExecuteRegionOp : SCF_Op<"execute_region"> {
|
||||||
affine.for %i = 0 to 100 {
|
affine.for %i = 0 to 100 {
|
||||||
"foo"() : () -> ()
|
"foo"() : () -> ()
|
||||||
%v = scf.execute_region -> i64 {
|
%v = scf.execute_region -> i64 {
|
||||||
cond_br %cond, ^bb1, ^bb2
|
cf.cond_br %cond, ^bb1, ^bb2
|
||||||
|
|
||||||
^bb1:
|
^bb1:
|
||||||
%c1 = arith.constant 1 : i64
|
%c1 = arith.constant 1 : i64
|
||||||
br ^bb3(%c1 : i64)
|
cf.br ^bb3(%c1 : i64)
|
||||||
|
|
||||||
^bb2:
|
^bb2:
|
||||||
%c2 = arith.constant 2 : i64
|
%c2 = arith.constant 2 : i64
|
||||||
br ^bb3(%c2 : i64)
|
cf.br ^bb3(%c2 : i64)
|
||||||
|
|
||||||
^bb3(%x : i64):
|
^bb3(%x : i64):
|
||||||
scf.yield %x : i64
|
scf.yield %x : i64
|
||||||
|
|
|
@ -14,7 +14,7 @@
|
||||||
#ifndef MLIR_DIALECT_STANDARDOPS_IR_OPS_H
|
#ifndef MLIR_DIALECT_STANDARDOPS_IR_OPS_H
|
||||||
#define MLIR_DIALECT_STANDARDOPS_IR_OPS_H
|
#define MLIR_DIALECT_STANDARDOPS_IR_OPS_H
|
||||||
|
|
||||||
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
|
#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
|
||||||
#include "mlir/IR/Builders.h"
|
#include "mlir/IR/Builders.h"
|
||||||
#include "mlir/IR/BuiltinTypes.h"
|
#include "mlir/IR/BuiltinTypes.h"
|
||||||
#include "mlir/IR/Dialect.h"
|
#include "mlir/IR/Dialect.h"
|
||||||
|
@ -24,7 +24,6 @@
|
||||||
#include "mlir/Interfaces/ControlFlowInterfaces.h"
|
#include "mlir/Interfaces/ControlFlowInterfaces.h"
|
||||||
#include "mlir/Interfaces/InferTypeOpInterface.h"
|
#include "mlir/Interfaces/InferTypeOpInterface.h"
|
||||||
#include "mlir/Interfaces/SideEffectInterfaces.h"
|
#include "mlir/Interfaces/SideEffectInterfaces.h"
|
||||||
#include "mlir/Interfaces/VectorInterfaces.h"
|
|
||||||
|
|
||||||
// Pull in all enum type definitions and utility function declarations.
|
// Pull in all enum type definitions and utility function declarations.
|
||||||
#include "mlir/Dialect/StandardOps/IR/OpsEnums.h.inc"
|
#include "mlir/Dialect/StandardOps/IR/OpsEnums.h.inc"
|
||||||
|
|
|
@ -20,12 +20,11 @@ include "mlir/Interfaces/CastInterfaces.td"
|
||||||
include "mlir/Interfaces/ControlFlowInterfaces.td"
|
include "mlir/Interfaces/ControlFlowInterfaces.td"
|
||||||
include "mlir/Interfaces/InferTypeOpInterface.td"
|
include "mlir/Interfaces/InferTypeOpInterface.td"
|
||||||
include "mlir/Interfaces/SideEffectInterfaces.td"
|
include "mlir/Interfaces/SideEffectInterfaces.td"
|
||||||
include "mlir/Interfaces/VectorInterfaces.td"
|
|
||||||
|
|
||||||
def StandardOps_Dialect : Dialect {
|
def StandardOps_Dialect : Dialect {
|
||||||
let name = "std";
|
let name = "std";
|
||||||
let cppNamespace = "::mlir";
|
let cppNamespace = "::mlir";
|
||||||
let dependentDialects = ["arith::ArithmeticDialect"];
|
let dependentDialects = ["cf::ControlFlowDialect"];
|
||||||
let hasConstantMaterializer = 1;
|
let hasConstantMaterializer = 1;
|
||||||
let emitAccessorPrefix = kEmitAccessorPrefix_Prefixed;
|
let emitAccessorPrefix = kEmitAccessorPrefix_Prefixed;
|
||||||
}
|
}
|
||||||
|
@ -42,78 +41,6 @@ class Std_Op<string mnemonic, list<Trait> traits = []> :
|
||||||
let parser = [{ return ::parse$cppClass(parser, result); }];
|
let parser = [{ return ::parse$cppClass(parser, result); }];
|
||||||
}
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
|
||||||
// AssertOp
|
|
||||||
//===----------------------------------------------------------------------===//
|
|
||||||
|
|
||||||
def AssertOp : Std_Op<"assert"> {
|
|
||||||
let summary = "Assert operation with message attribute";
|
|
||||||
let description = [{
|
|
||||||
Assert operation with single boolean operand and an error message attribute.
|
|
||||||
If the argument is `true` this operation has no effect. Otherwise, the
|
|
||||||
program execution will abort. The provided error message may be used by a
|
|
||||||
runtime to propagate the error to the user.
|
|
||||||
|
|
||||||
Example:
|
|
||||||
|
|
||||||
```mlir
|
|
||||||
assert %b, "Expected ... to be true"
|
|
||||||
```
|
|
||||||
}];
|
|
||||||
|
|
||||||
let arguments = (ins I1:$arg, StrAttr:$msg);
|
|
||||||
|
|
||||||
let assemblyFormat = "$arg `,` $msg attr-dict";
|
|
||||||
let hasCanonicalizeMethod = 1;
|
|
||||||
}
|
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
|
||||||
// BranchOp
|
|
||||||
//===----------------------------------------------------------------------===//
|
|
||||||
|
|
||||||
def BranchOp : Std_Op<"br",
|
|
||||||
[DeclareOpInterfaceMethods<BranchOpInterface, ["getSuccessorForOperands"]>,
|
|
||||||
NoSideEffect, Terminator]> {
|
|
||||||
let summary = "branch operation";
|
|
||||||
let description = [{
|
|
||||||
The `br` operation represents a branch operation in a function.
|
|
||||||
The operation takes variable number of operands and produces no results.
|
|
||||||
The operand number and types for each successor must match the arguments of
|
|
||||||
the block successor.
|
|
||||||
|
|
||||||
Example:
|
|
||||||
|
|
||||||
```mlir
|
|
||||||
^bb2:
|
|
||||||
%2 = call @someFn()
|
|
||||||
br ^bb3(%2 : tensor<*xf32>)
|
|
||||||
^bb3(%3: tensor<*xf32>):
|
|
||||||
```
|
|
||||||
}];
|
|
||||||
|
|
||||||
let arguments = (ins Variadic<AnyType>:$destOperands);
|
|
||||||
let successors = (successor AnySuccessor:$dest);
|
|
||||||
|
|
||||||
let builders = [
|
|
||||||
OpBuilder<(ins "Block *":$dest,
|
|
||||||
CArg<"ValueRange", "{}">:$destOperands), [{
|
|
||||||
$_state.addSuccessors(dest);
|
|
||||||
$_state.addOperands(destOperands);
|
|
||||||
}]>];
|
|
||||||
|
|
||||||
let extraClassDeclaration = [{
|
|
||||||
void setDest(Block *block);
|
|
||||||
|
|
||||||
/// Erase the operand at 'index' from the operand list.
|
|
||||||
void eraseOperand(unsigned index);
|
|
||||||
}];
|
|
||||||
|
|
||||||
let hasCanonicalizeMethod = 1;
|
|
||||||
let assemblyFormat = [{
|
|
||||||
$dest (`(` $destOperands^ `:` type($destOperands) `)`)? attr-dict
|
|
||||||
}];
|
|
||||||
}
|
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// CallOp
|
// CallOp
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
@ -246,121 +173,6 @@ def CallIndirectOp : Std_Op<"call_indirect", [
|
||||||
"$callee `(` $callee_operands `)` attr-dict `:` type($callee)";
|
"$callee `(` $callee_operands `)` attr-dict `:` type($callee)";
|
||||||
}
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
|
||||||
// CondBranchOp
|
|
||||||
//===----------------------------------------------------------------------===//
|
|
||||||
|
|
||||||
def CondBranchOp : Std_Op<"cond_br",
|
|
||||||
[AttrSizedOperandSegments,
|
|
||||||
DeclareOpInterfaceMethods<BranchOpInterface, ["getSuccessorForOperands"]>,
|
|
||||||
NoSideEffect, Terminator]> {
|
|
||||||
let summary = "conditional branch operation";
|
|
||||||
let description = [{
|
|
||||||
The `cond_br` terminator operation represents a conditional branch on a
|
|
||||||
boolean (1-bit integer) value. If the bit is set, then the first destination
|
|
||||||
is jumped to; if it is false, the second destination is chosen. The count
|
|
||||||
and types of operands must align with the arguments in the corresponding
|
|
||||||
target blocks.
|
|
||||||
|
|
||||||
The MLIR conditional branch operation is not allowed to target the entry
|
|
||||||
block for a region. The two destinations of the conditional branch operation
|
|
||||||
are allowed to be the same.
|
|
||||||
|
|
||||||
The following example illustrates a function with a conditional branch
|
|
||||||
operation that targets the same block.
|
|
||||||
|
|
||||||
Example:
|
|
||||||
|
|
||||||
```mlir
|
|
||||||
func @select(%a: i32, %b: i32, %flag: i1) -> i32 {
|
|
||||||
// Both targets are the same, operands differ
|
|
||||||
cond_br %flag, ^bb1(%a : i32), ^bb1(%b : i32)
|
|
||||||
|
|
||||||
^bb1(%x : i32) :
|
|
||||||
return %x : i32
|
|
||||||
}
|
|
||||||
```
|
|
||||||
}];
|
|
||||||
|
|
||||||
let arguments = (ins I1:$condition,
|
|
||||||
Variadic<AnyType>:$trueDestOperands,
|
|
||||||
Variadic<AnyType>:$falseDestOperands);
|
|
||||||
let successors = (successor AnySuccessor:$trueDest, AnySuccessor:$falseDest);
|
|
||||||
|
|
||||||
let builders = [
|
|
||||||
OpBuilder<(ins "Value":$condition, "Block *":$trueDest,
|
|
||||||
"ValueRange":$trueOperands, "Block *":$falseDest,
|
|
||||||
"ValueRange":$falseOperands), [{
|
|
||||||
build($_builder, $_state, condition, trueOperands, falseOperands, trueDest,
|
|
||||||
falseDest);
|
|
||||||
}]>,
|
|
||||||
OpBuilder<(ins "Value":$condition, "Block *":$trueDest,
|
|
||||||
"Block *":$falseDest, CArg<"ValueRange", "{}">:$falseOperands), [{
|
|
||||||
build($_builder, $_state, condition, trueDest, ValueRange(), falseDest,
|
|
||||||
falseOperands);
|
|
||||||
}]>];
|
|
||||||
|
|
||||||
let extraClassDeclaration = [{
|
|
||||||
// These are the indices into the dests list.
|
|
||||||
enum { trueIndex = 0, falseIndex = 1 };
|
|
||||||
|
|
||||||
// Accessors for operands to the 'true' destination.
|
|
||||||
Value getTrueOperand(unsigned idx) {
|
|
||||||
assert(idx < getNumTrueOperands());
|
|
||||||
return getOperand(getTrueDestOperandIndex() + idx);
|
|
||||||
}
|
|
||||||
|
|
||||||
void setTrueOperand(unsigned idx, Value value) {
|
|
||||||
assert(idx < getNumTrueOperands());
|
|
||||||
setOperand(getTrueDestOperandIndex() + idx, value);
|
|
||||||
}
|
|
||||||
|
|
||||||
unsigned getNumTrueOperands() { return getTrueOperands().size(); }
|
|
||||||
|
|
||||||
/// Erase the operand at 'index' from the true operand list.
|
|
||||||
void eraseTrueOperand(unsigned index) {
|
|
||||||
getTrueDestOperandsMutable().erase(index);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Accessors for operands to the 'false' destination.
|
|
||||||
Value getFalseOperand(unsigned idx) {
|
|
||||||
assert(idx < getNumFalseOperands());
|
|
||||||
return getOperand(getFalseDestOperandIndex() + idx);
|
|
||||||
}
|
|
||||||
void setFalseOperand(unsigned idx, Value value) {
|
|
||||||
assert(idx < getNumFalseOperands());
|
|
||||||
setOperand(getFalseDestOperandIndex() + idx, value);
|
|
||||||
}
|
|
||||||
|
|
||||||
operand_range getTrueOperands() { return getTrueDestOperands(); }
|
|
||||||
operand_range getFalseOperands() { return getFalseDestOperands(); }
|
|
||||||
|
|
||||||
unsigned getNumFalseOperands() { return getFalseOperands().size(); }
|
|
||||||
|
|
||||||
/// Erase the operand at 'index' from the false operand list.
|
|
||||||
void eraseFalseOperand(unsigned index) {
|
|
||||||
getFalseDestOperandsMutable().erase(index);
|
|
||||||
}
|
|
||||||
|
|
||||||
private:
|
|
||||||
/// Get the index of the first true destination operand.
|
|
||||||
unsigned getTrueDestOperandIndex() { return 1; }
|
|
||||||
|
|
||||||
/// Get the index of the first false destination operand.
|
|
||||||
unsigned getFalseDestOperandIndex() {
|
|
||||||
return getTrueDestOperandIndex() + getNumTrueOperands();
|
|
||||||
}
|
|
||||||
}];
|
|
||||||
|
|
||||||
let hasCanonicalizer = 1;
|
|
||||||
let assemblyFormat = [{
|
|
||||||
$condition `,`
|
|
||||||
$trueDest (`(` $trueDestOperands^ `:` type($trueDestOperands) `)`)? `,`
|
|
||||||
$falseDest (`(` $falseDestOperands^ `:` type($falseDestOperands) `)`)?
|
|
||||||
attr-dict
|
|
||||||
}];
|
|
||||||
}
|
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// ConstantOp
|
// ConstantOp
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
@ -443,93 +255,4 @@ def ReturnOp : Std_Op<"return", [NoSideEffect, HasParent<"FuncOp">,
|
||||||
let hasVerifier = 1;
|
let hasVerifier = 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
|
||||||
// SwitchOp
|
|
||||||
//===----------------------------------------------------------------------===//
|
|
||||||
|
|
||||||
def SwitchOp : Std_Op<"switch",
|
|
||||||
[AttrSizedOperandSegments,
|
|
||||||
DeclareOpInterfaceMethods<BranchOpInterface, ["getSuccessorForOperands"]>,
|
|
||||||
NoSideEffect, Terminator]> {
|
|
||||||
let summary = "switch operation";
|
|
||||||
let description = [{
|
|
||||||
The `switch` terminator operation represents a switch on a signless integer
|
|
||||||
value. If the flag matches one of the specified cases, then the
|
|
||||||
corresponding destination is jumped to. If the flag does not match any of
|
|
||||||
the cases, the default destination is jumped to. The count and types of
|
|
||||||
operands must align with the arguments in the corresponding target blocks.
|
|
||||||
|
|
||||||
Example:
|
|
||||||
|
|
||||||
```mlir
|
|
||||||
switch %flag : i32, [
|
|
||||||
default: ^bb1(%a : i32),
|
|
||||||
42: ^bb1(%b : i32),
|
|
||||||
43: ^bb3(%c : i32)
|
|
||||||
]
|
|
||||||
```
|
|
||||||
}];
|
|
||||||
|
|
||||||
let arguments = (ins
|
|
||||||
AnyInteger:$flag,
|
|
||||||
Variadic<AnyType>:$defaultOperands,
|
|
||||||
VariadicOfVariadic<AnyType, "case_operand_segments">:$caseOperands,
|
|
||||||
OptionalAttr<AnyIntElementsAttr>:$case_values,
|
|
||||||
I32ElementsAttr:$case_operand_segments
|
|
||||||
);
|
|
||||||
let successors = (successor
|
|
||||||
AnySuccessor:$defaultDestination,
|
|
||||||
VariadicSuccessor<AnySuccessor>:$caseDestinations
|
|
||||||
);
|
|
||||||
let builders = [
|
|
||||||
OpBuilder<(ins "Value":$flag,
|
|
||||||
"Block *":$defaultDestination,
|
|
||||||
"ValueRange":$defaultOperands,
|
|
||||||
CArg<"ArrayRef<APInt>", "{}">:$caseValues,
|
|
||||||
CArg<"BlockRange", "{}">:$caseDestinations,
|
|
||||||
CArg<"ArrayRef<ValueRange>", "{}">:$caseOperands)>,
|
|
||||||
OpBuilder<(ins "Value":$flag,
|
|
||||||
"Block *":$defaultDestination,
|
|
||||||
"ValueRange":$defaultOperands,
|
|
||||||
CArg<"ArrayRef<int32_t>", "{}">:$caseValues,
|
|
||||||
CArg<"BlockRange", "{}">:$caseDestinations,
|
|
||||||
CArg<"ArrayRef<ValueRange>", "{}">:$caseOperands)>,
|
|
||||||
OpBuilder<(ins "Value":$flag,
|
|
||||||
"Block *":$defaultDestination,
|
|
||||||
"ValueRange":$defaultOperands,
|
|
||||||
CArg<"DenseIntElementsAttr", "{}">:$caseValues,
|
|
||||||
CArg<"BlockRange", "{}">:$caseDestinations,
|
|
||||||
CArg<"ArrayRef<ValueRange>", "{}">:$caseOperands)>
|
|
||||||
];
|
|
||||||
|
|
||||||
let assemblyFormat = [{
|
|
||||||
$flag `:` type($flag) `,` `[` `\n`
|
|
||||||
custom<SwitchOpCases>(ref(type($flag)),$defaultDestination,
|
|
||||||
$defaultOperands,
|
|
||||||
type($defaultOperands),
|
|
||||||
$case_values,
|
|
||||||
$caseDestinations,
|
|
||||||
$caseOperands,
|
|
||||||
type($caseOperands))
|
|
||||||
`]`
|
|
||||||
attr-dict
|
|
||||||
}];
|
|
||||||
|
|
||||||
let extraClassDeclaration = [{
|
|
||||||
/// Return the operands for the case destination block at the given index.
|
|
||||||
OperandRange getCaseOperands(unsigned index) {
|
|
||||||
return getCaseOperands()[index];
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Return a mutable range of operands for the case destination block at the
|
|
||||||
/// given index.
|
|
||||||
MutableOperandRange getCaseOperandsMutable(unsigned index) {
|
|
||||||
return getCaseOperandsMutable()[index];
|
|
||||||
}
|
|
||||||
}];
|
|
||||||
|
|
||||||
let hasCanonicalizer = 1;
|
|
||||||
let hasVerifier = 1;
|
|
||||||
}
|
|
||||||
|
|
||||||
#endif // STANDARD_OPS
|
#endif // STANDARD_OPS
|
||||||
|
|
|
@ -22,6 +22,7 @@
|
||||||
#include "mlir/Dialect/Async/IR/Async.h"
|
#include "mlir/Dialect/Async/IR/Async.h"
|
||||||
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
|
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
|
||||||
#include "mlir/Dialect/Complex/IR/Complex.h"
|
#include "mlir/Dialect/Complex/IR/Complex.h"
|
||||||
|
#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
|
||||||
#include "mlir/Dialect/DLTI/DLTI.h"
|
#include "mlir/Dialect/DLTI/DLTI.h"
|
||||||
#include "mlir/Dialect/EmitC/IR/EmitC.h"
|
#include "mlir/Dialect/EmitC/IR/EmitC.h"
|
||||||
#include "mlir/Dialect/GPU/GPUDialect.h"
|
#include "mlir/Dialect/GPU/GPUDialect.h"
|
||||||
|
@ -61,6 +62,7 @@ inline void registerAllDialects(DialectRegistry ®istry) {
|
||||||
arm_neon::ArmNeonDialect,
|
arm_neon::ArmNeonDialect,
|
||||||
async::AsyncDialect,
|
async::AsyncDialect,
|
||||||
bufferization::BufferizationDialect,
|
bufferization::BufferizationDialect,
|
||||||
|
cf::ControlFlowDialect,
|
||||||
complex::ComplexDialect,
|
complex::ComplexDialect,
|
||||||
DLTIDialect,
|
DLTIDialect,
|
||||||
emitc::EmitCDialect,
|
emitc::EmitCDialect,
|
||||||
|
|
|
@ -6,6 +6,8 @@ add_subdirectory(AsyncToLLVM)
|
||||||
add_subdirectory(BufferizationToMemRef)
|
add_subdirectory(BufferizationToMemRef)
|
||||||
add_subdirectory(ComplexToLLVM)
|
add_subdirectory(ComplexToLLVM)
|
||||||
add_subdirectory(ComplexToStandard)
|
add_subdirectory(ComplexToStandard)
|
||||||
|
add_subdirectory(ControlFlowToLLVM)
|
||||||
|
add_subdirectory(ControlFlowToSPIRV)
|
||||||
add_subdirectory(GPUCommon)
|
add_subdirectory(GPUCommon)
|
||||||
add_subdirectory(GPUToNVVM)
|
add_subdirectory(GPUToNVVM)
|
||||||
add_subdirectory(GPUToROCDL)
|
add_subdirectory(GPUToROCDL)
|
||||||
|
@ -25,10 +27,10 @@ add_subdirectory(OpenACCToSCF)
|
||||||
add_subdirectory(OpenMPToLLVM)
|
add_subdirectory(OpenMPToLLVM)
|
||||||
add_subdirectory(PDLToPDLInterp)
|
add_subdirectory(PDLToPDLInterp)
|
||||||
add_subdirectory(ReconcileUnrealizedCasts)
|
add_subdirectory(ReconcileUnrealizedCasts)
|
||||||
|
add_subdirectory(SCFToControlFlow)
|
||||||
add_subdirectory(SCFToGPU)
|
add_subdirectory(SCFToGPU)
|
||||||
add_subdirectory(SCFToOpenMP)
|
add_subdirectory(SCFToOpenMP)
|
||||||
add_subdirectory(SCFToSPIRV)
|
add_subdirectory(SCFToSPIRV)
|
||||||
add_subdirectory(SCFToStandard)
|
|
||||||
add_subdirectory(ShapeToStandard)
|
add_subdirectory(ShapeToStandard)
|
||||||
add_subdirectory(SPIRVToLLVM)
|
add_subdirectory(SPIRVToLLVM)
|
||||||
add_subdirectory(StandardToLLVM)
|
add_subdirectory(StandardToLLVM)
|
||||||
|
|
|
@ -0,0 +1,21 @@
|
||||||
|
add_mlir_conversion_library(MLIRControlFlowToLLVM
|
||||||
|
ControlFlowToLLVM.cpp
|
||||||
|
|
||||||
|
ADDITIONAL_HEADER_DIRS
|
||||||
|
${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/ControlFlowToLLVM
|
||||||
|
|
||||||
|
DEPENDS
|
||||||
|
MLIRConversionPassIncGen
|
||||||
|
intrinsics_gen
|
||||||
|
|
||||||
|
LINK_COMPONENTS
|
||||||
|
Core
|
||||||
|
|
||||||
|
LINK_LIBS PUBLIC
|
||||||
|
MLIRAnalysis
|
||||||
|
MLIRControlFlow
|
||||||
|
MLIRLLVMCommonConversion
|
||||||
|
MLIRLLVMIR
|
||||||
|
MLIRPass
|
||||||
|
MLIRTransformUtils
|
||||||
|
)
|
|
@ -0,0 +1,148 @@
|
||||||
|
//===- ControlFlowToLLVM.cpp - ControlFlow to LLVM dialect conversion -----===//
|
||||||
|
//
|
||||||
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
// See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
//
|
||||||
|
// This file implements a pass to convert the MLIR ControlFlow dialect
|
||||||
|
// into the LLVM IR dialect.
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
|
||||||
|
#include "../PassDetail.h"
|
||||||
|
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
|
||||||
|
#include "mlir/Conversion/LLVMCommon/Pattern.h"
|
||||||
|
#include "mlir/Conversion/LLVMCommon/VectorPattern.h"
|
||||||
|
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
|
||||||
|
#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
|
||||||
|
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
|
||||||
|
#include "mlir/IR/BuiltinOps.h"
|
||||||
|
#include "mlir/IR/PatternMatch.h"
|
||||||
|
#include "mlir/Transforms/DialectConversion.h"
|
||||||
|
#include <functional>
|
||||||
|
|
||||||
|
using namespace mlir;
|
||||||
|
|
||||||
|
#define PASS_NAME "convert-cf-to-llvm"
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
/// Lower `cf.assert`. The default lowering calls the `abort` function if the
|
||||||
|
/// assertion is violated and has no effect otherwise. The failure message is
|
||||||
|
/// ignored by the default lowering but should be propagated by any custom
|
||||||
|
/// lowering.
///
/// The rewrite splits the current block at the insertion point, adds a
/// failure block that calls `abort` and ends in `llvm.unreachable`, and
/// replaces the assert with an `llvm.cond_br` on the converted condition.
|
||||||
|
struct AssertOpLowering : public ConvertOpToLLVMPattern<cf::AssertOp> {
|
||||||
|
// Inherit the base pattern's constructors.
using ConvertOpToLLVMPattern<cf::AssertOp>::ConvertOpToLLVMPattern;
|
||||||
|
|
||||||
|
/// Rewrites `op` as described on the struct; always returns success().
LogicalResult
|
||||||
|
matchAndRewrite(cf::AssertOp op, OpAdaptor adaptor,
|
||||||
|
ConversionPatternRewriter &rewriter) const override {
|
||||||
|
auto loc = op.getLoc();
|
||||||
|
|
||||||
|
// Insert the `abort` declaration if necessary.
|
||||||
|
// NOTE(review): `module` is used without a null check; presumably the op
// is always nested in a builtin ModuleOp -- confirm for other containers.
auto module = op->getParentOfType<ModuleOp>();
|
||||||
|
auto abortFunc = module.lookupSymbol<LLVM::LLVMFuncOp>("abort");
|
||||||
|
if (!abortFunc) {
|
||||||
|
// The guard restores the rewriter's insertion point when this scope ends.
OpBuilder::InsertionGuard guard(rewriter);
|
||||||
|
rewriter.setInsertionPointToStart(module.getBody());
|
||||||
|
// `abort` is declared with a void result and no arguments.
auto abortFuncTy = LLVM::LLVMFunctionType::get(getVoidType(), {});
|
||||||
|
abortFunc = rewriter.create<LLVM::LLVMFuncOp>(rewriter.getUnknownLoc(),
|
||||||
|
"abort", abortFuncTy);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Split block at `assert` operation.
|
||||||
|
// NOTE(review): this uses the rewriter's current insertion point rather
// than `op` itself; presumably the driver positions it at `op` -- confirm.
Block *opBlock = rewriter.getInsertionBlock();
|
||||||
|
auto opPosition = rewriter.getInsertionPoint();
|
||||||
|
Block *continuationBlock = rewriter.splitBlock(opBlock, opPosition);
|
||||||
|
|
||||||
|
// Generate IR to call `abort`.
|
||||||
|
// The failure block is created in the same region as the original block.
Block *failureBlock = rewriter.createBlock(opBlock->getParent());
|
||||||
|
rewriter.create<LLVM::CallOp>(loc, abortFunc, llvm::None);
|
||||||
|
rewriter.create<LLVM::UnreachableOp>(loc);
|
||||||
|
|
||||||
|
// Generate assertion test.
|
||||||
|
// Condition true -> continuationBlock, condition false -> failureBlock.
rewriter.setInsertionPointToEnd(opBlock);
|
||||||
|
rewriter.replaceOpWithNewOp<LLVM::CondBrOp>(
|
||||||
|
op, adaptor.getArg(), continuationBlock, failureBlock);
|
||||||
|
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// Base class for LLVM IR lowering terminator operations with successors.
// `SourceOp` is the op being matched and `TargetOp` is the op created in its
// place; the replacement receives the converted operands together with the
// original successors and attributes.
|
||||||
|
template <typename SourceOp, typename TargetOp>
|
||||||
|
struct OneToOneLLVMTerminatorLowering
|
||||||
|
: public ConvertOpToLLVMPattern<SourceOp> {
|
||||||
|
// Inherit the base pattern's constructors.
using ConvertOpToLLVMPattern<SourceOp>::ConvertOpToLLVMPattern;
|
||||||
|
// Alias so that derived structs can write `using Base::Base;`.
using Base = OneToOneLLVMTerminatorLowering<SourceOp, TargetOp>;
|
||||||
|
|
||||||
|
// Replaces `op` with a `TargetOp`; always returns success().
LogicalResult
|
||||||
|
matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor,
|
||||||
|
ConversionPatternRewriter &rewriter) const override {
|
||||||
|
// Operands come from the adaptor (converted values); successors and
// attributes are taken unchanged from the original op.
rewriter.replaceOpWithNewOp<TargetOp>(op, adaptor.getOperands(),
|
||||||
|
op->getSuccessors(), op->getAttrs());
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// FIXME: this should be tablegen'ed as well.
// Lowers cf::BranchOp to LLVM::BrOp through the one-to-one terminator base.
|
||||||
|
struct BranchOpLowering
|
||||||
|
: public OneToOneLLVMTerminatorLowering<cf::BranchOp, LLVM::BrOp> {
|
||||||
|
// Inherit the base pattern's constructors.
using Base::Base;
|
||||||
|
};
|
||||||
|
// Lowers cf::CondBranchOp to LLVM::CondBrOp through the one-to-one
// terminator base.
struct CondBranchOpLowering
|
||||||
|
: public OneToOneLLVMTerminatorLowering<cf::CondBranchOp, LLVM::CondBrOp> {
|
||||||
|
// Inherit the base pattern's constructors.
using Base::Base;
|
||||||
|
};
|
||||||
|
// Lowers cf::SwitchOp to LLVM::SwitchOp through the one-to-one terminator
// base.
struct SwitchOpLowering
|
||||||
|
: public OneToOneLLVMTerminatorLowering<cf::SwitchOp, LLVM::SwitchOp> {
|
||||||
|
// Inherit the base pattern's constructors.
using Base::Base;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
// Adds the ControlFlow-to-LLVM lowering patterns defined in this file
// (assert, branch, conditional branch and switch) to `patterns`, each
// constructed with `converter`.
void mlir::cf::populateControlFlowToLLVMConversionPatterns(
|
||||||
|
LLVMTypeConverter &converter, RewritePatternSet &patterns) {
|
||||||
|
// clang-format off
|
||||||
|
patterns.add<
|
||||||
|
AssertOpLowering,
|
||||||
|
BranchOpLowering,
|
||||||
|
CondBranchOpLowering,
|
||||||
|
SwitchOpLowering>(converter);
|
||||||
|
// clang-format on
|
||||||
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// Pass Definition
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
/// A pass converting MLIR operations into the LLVM IR dialect.
/// Only the ControlFlow patterns populated by
/// populateControlFlowToLLVMConversionPatterns are applied, using a partial
/// conversion.
|
||||||
|
struct ConvertControlFlowToLLVM
|
||||||
|
: public ConvertControlFlowToLLVMBase<ConvertControlFlowToLLVM> {
|
||||||
|
ConvertControlFlowToLLVM() = default;
|
||||||
|
|
||||||
|
/// Run the dialect converter on the module.
|
||||||
|
void runOnOperation() override {
|
||||||
|
LLVMConversionTarget target(getContext());
|
||||||
|
RewritePatternSet patterns(&getContext());
|
||||||
|
|
||||||
|
LowerToLLVMOptions options(&getContext());
|
||||||
|
// Override the index bitwidth only when the `indexBitwidth` option differs
// from the "derive from data layout" sentinel.
if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout)
|
||||||
|
options.overrideIndexBitwidth(indexBitwidth);
|
||||||
|
|
||||||
|
LLVMTypeConverter converter(&getContext(), options);
|
||||||
|
mlir::cf::populateControlFlowToLLVMConversionPatterns(converter, patterns);
|
||||||
|
|
||||||
|
// Signal pass failure if the partial conversion fails.
if (failed(applyPartialConversion(getOperation(), target,
|
||||||
|
std::move(patterns))))
|
||||||
|
signalPassFailure();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
// Returns a newly allocated ConvertControlFlowToLLVM pass instance.
std::unique_ptr<Pass> mlir::cf::createConvertControlFlowToLLVMPass() {
|
||||||
|
return std::make_unique<ConvertControlFlowToLLVM>();
|
||||||
|
}
|
|
@ -0,0 +1,19 @@
|
||||||
|
add_mlir_conversion_library(MLIRControlFlowToSPIRV
|
||||||
|
ControlFlowToSPIRV.cpp
|
||||||
|
|
||||||
|
ADDITIONAL_HEADER_DIRS
|
||||||
|
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/SPIRV
|
||||||
|
${MLIR_MAIN_INCLUDE_DIR}/mlir/IR
|
||||||
|
|
||||||
|
DEPENDS
|
||||||
|
MLIRConversionPassIncGen
|
||||||
|
|
||||||
|
LINK_LIBS PUBLIC
|
||||||
|
MLIRIR
|
||||||
|
MLIRControlFlow
|
||||||
|
MLIRPass
|
||||||
|
MLIRSPIRV
|
||||||
|
MLIRSPIRVConversion
|
||||||
|
MLIRSupport
|
||||||
|
MLIRTransformUtils
|
||||||
|
)
|
|
@ -0,0 +1,73 @@
|
||||||
|
//===- ControlFlowToSPIRV.cpp - ControlFlow to SPIR-V Patterns ------------===//
|
||||||
|
//
|
||||||
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
// See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
//
|
||||||
|
// This file implements patterns to convert the ControlFlow dialect to SPIR-V.
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#include "mlir/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRV.h"
|
||||||
|
#include "../SPIRVCommon/Pattern.h"
|
||||||
|
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
|
||||||
|
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
|
||||||
|
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
|
||||||
|
#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
|
||||||
|
#include "mlir/Dialect/SPIRV/Utils/LayoutUtils.h"
|
||||||
|
#include "mlir/IR/AffineMap.h"
|
||||||
|
#include "mlir/Support/LogicalResult.h"
|
||||||
|
#include "llvm/ADT/SetVector.h"
|
||||||
|
#include "llvm/Support/Debug.h"
|
||||||
|
|
||||||
|
#define DEBUG_TYPE "cf-to-spirv-pattern"
|
||||||
|
|
||||||
|
using namespace mlir;
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// Operation conversion
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
/// Converts cf.br to spv.Branch.
/// The destination block is taken from the original op; the destination
/// operands are the converted values from the adaptor.
|
||||||
|
struct BranchOpPattern final : public OpConversionPattern<cf::BranchOp> {
|
||||||
|
// Inherit the base pattern's constructors.
using OpConversionPattern<cf::BranchOp>::OpConversionPattern;
|
||||||
|
|
||||||
|
/// Replaces `op` with a spirv::BranchOp; always returns success().
LogicalResult
|
||||||
|
matchAndRewrite(cf::BranchOp op, OpAdaptor adaptor,
|
||||||
|
ConversionPatternRewriter &rewriter) const override {
|
||||||
|
rewriter.replaceOpWithNewOp<spirv::BranchOp>(op, op.getDest(),
|
||||||
|
adaptor.getDestOperands());
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
/// Converts cf.cond_br to spv.BranchConditional.
/// Both destination blocks are taken from the original op; the true and
/// false destination operands are the converted values from the adaptor.
|
||||||
|
struct CondBranchOpPattern final
|
||||||
|
: public OpConversionPattern<cf::CondBranchOp> {
|
||||||
|
// Inherit the base pattern's constructors.
using OpConversionPattern<cf::CondBranchOp>::OpConversionPattern;
|
||||||
|
|
||||||
|
/// Replaces `op` with a spirv::BranchConditionalOp; always returns success().
LogicalResult
|
||||||
|
matchAndRewrite(cf::CondBranchOp op, OpAdaptor adaptor,
|
||||||
|
ConversionPatternRewriter &rewriter) const override {
|
||||||
|
// NOTE(review): the condition is read from `op`, not from `adaptor`, unlike
// the branch operands; presumably `adaptor.getCondition()` is the value
// intended inside a conversion pattern -- confirm.
rewriter.replaceOpWithNewOp<spirv::BranchConditionalOp>(
|
||||||
|
op, op.getCondition(), op.getTrueDest(), adaptor.getTrueDestOperands(),
|
||||||
|
op.getFalseDest(), adaptor.getFalseDestOperands());
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// Pattern population
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
// Adds the cf.br and cf.cond_br SPIR-V conversion patterns defined in this
// file to `patterns`, constructed with `typeConverter` and the context that
// `patterns` was created with.
void mlir::cf::populateControlFlowToSPIRVPatterns(
|
||||||
|
SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns) {
|
||||||
|
MLIRContext *context = patterns.getContext();
|
||||||
|
|
||||||
|
patterns.add<BranchOpPattern, CondBranchOpPattern>(typeConverter, context);
|
||||||
|
}
|
|
@ -14,6 +14,7 @@
|
||||||
#include "mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h"
|
#include "mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h"
|
||||||
|
|
||||||
#include "mlir/Conversion/ArithmeticToLLVM/ArithmeticToLLVM.h"
|
#include "mlir/Conversion/ArithmeticToLLVM/ArithmeticToLLVM.h"
|
||||||
|
#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
|
||||||
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
|
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
|
||||||
#include "mlir/Conversion/LLVMCommon/LoweringOptions.h"
|
#include "mlir/Conversion/LLVMCommon/LoweringOptions.h"
|
||||||
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
|
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
|
||||||
|
@ -172,8 +173,8 @@ struct LowerGpuOpsToNVVMOpsPass
|
||||||
populateGpuRewritePatterns(patterns);
|
populateGpuRewritePatterns(patterns);
|
||||||
(void)applyPatternsAndFoldGreedily(m, std::move(patterns));
|
(void)applyPatternsAndFoldGreedily(m, std::move(patterns));
|
||||||
|
|
||||||
mlir::arith::populateArithmeticToLLVMConversionPatterns(converter,
|
arith::populateArithmeticToLLVMConversionPatterns(converter, llvmPatterns);
|
||||||
llvmPatterns);
|
cf::populateControlFlowToLLVMConversionPatterns(converter, llvmPatterns);
|
||||||
populateStdToLLVMConversionPatterns(converter, llvmPatterns);
|
populateStdToLLVMConversionPatterns(converter, llvmPatterns);
|
||||||
populateMemRefToLLVMConversionPatterns(converter, llvmPatterns);
|
populateMemRefToLLVMConversionPatterns(converter, llvmPatterns);
|
||||||
populateGpuToNVVMConversionPatterns(converter, llvmPatterns);
|
populateGpuToNVVMConversionPatterns(converter, llvmPatterns);
|
||||||
|
|
|
@ -19,7 +19,7 @@ add_mlir_conversion_library(MLIRLinalgToLLVM
|
||||||
MLIRLLVMCommonConversion
|
MLIRLLVMCommonConversion
|
||||||
MLIRLLVMIR
|
MLIRLLVMIR
|
||||||
MLIRMemRefToLLVM
|
MLIRMemRefToLLVM
|
||||||
MLIRSCFToStandard
|
MLIRSCFToControlFlow
|
||||||
MLIRTransforms
|
MLIRTransforms
|
||||||
MLIRVectorToLLVM
|
MLIRVectorToLLVM
|
||||||
MLIRVectorToSCF
|
MLIRVectorToSCF
|
||||||
|
|
|
@ -14,7 +14,7 @@
|
||||||
#include "mlir/Conversion/LLVMCommon/Pattern.h"
|
#include "mlir/Conversion/LLVMCommon/Pattern.h"
|
||||||
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
|
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
|
||||||
#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
|
#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
|
||||||
#include "mlir/Conversion/SCFToStandard/SCFToStandard.h"
|
#include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h"
|
||||||
#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
|
#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
|
||||||
#include "mlir/Conversion/VectorToSCF/VectorToSCF.h"
|
#include "mlir/Conversion/VectorToSCF/VectorToSCF.h"
|
||||||
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
|
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
|
||||||
|
|
|
@ -433,7 +433,7 @@ struct LoadStoreOpLowering : public ConvertOpToLLVMPattern<Derived> {
|
||||||
/// +---------------------------------+
|
/// +---------------------------------+
|
||||||
/// | <code before the AtomicRMWOp> |
|
/// | <code before the AtomicRMWOp> |
|
||||||
/// | <compute initial %loaded> |
|
/// | <compute initial %loaded> |
|
||||||
/// | br loop(%loaded) |
|
/// | cf.br loop(%loaded) |
|
||||||
/// +---------------------------------+
|
/// +---------------------------------+
|
||||||
/// |
|
/// |
|
||||||
/// -------| |
|
/// -------| |
|
||||||
|
@ -444,7 +444,7 @@ struct LoadStoreOpLowering : public ConvertOpToLLVMPattern<Derived> {
|
||||||
/// | | %pair = cmpxchg |
|
/// | | %pair = cmpxchg |
|
||||||
/// | | %ok = %pair[0] |
|
/// | | %ok = %pair[0] |
|
||||||
/// | | %new = %pair[1] |
|
/// | | %new = %pair[1] |
|
||||||
/// | | cond_br %ok, end, loop(%new) |
|
/// | | cf.cond_br %ok, end, loop(%new) |
|
||||||
/// | +--------------------------------+
|
/// | +--------------------------------+
|
||||||
/// | | |
|
/// | | |
|
||||||
/// |----------- |
|
/// |----------- |
|
||||||
|
|
|
@ -10,6 +10,7 @@
|
||||||
|
|
||||||
#include "../PassDetail.h"
|
#include "../PassDetail.h"
|
||||||
#include "mlir/Conversion/ArithmeticToLLVM/ArithmeticToLLVM.h"
|
#include "mlir/Conversion/ArithmeticToLLVM/ArithmeticToLLVM.h"
|
||||||
|
#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
|
||||||
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
|
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
|
||||||
#include "mlir/Conversion/LLVMCommon/Pattern.h"
|
#include "mlir/Conversion/LLVMCommon/Pattern.h"
|
||||||
#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
|
#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
|
||||||
|
@ -66,7 +67,8 @@ void ConvertOpenMPToLLVMPass::runOnOperation() {
|
||||||
// Convert to OpenMP operations with LLVM IR dialect
|
// Convert to OpenMP operations with LLVM IR dialect
|
||||||
RewritePatternSet patterns(&getContext());
|
RewritePatternSet patterns(&getContext());
|
||||||
LLVMTypeConverter converter(&getContext());
|
LLVMTypeConverter converter(&getContext());
|
||||||
mlir::arith::populateArithmeticToLLVMConversionPatterns(converter, patterns);
|
arith::populateArithmeticToLLVMConversionPatterns(converter, patterns);
|
||||||
|
cf::populateControlFlowToLLVMConversionPatterns(converter, patterns);
|
||||||
populateMemRefToLLVMConversionPatterns(converter, patterns);
|
populateMemRefToLLVMConversionPatterns(converter, patterns);
|
||||||
populateStdToLLVMConversionPatterns(converter, patterns);
|
populateStdToLLVMConversionPatterns(converter, patterns);
|
||||||
populateOpenMPToLLVMConversionPatterns(converter, patterns);
|
populateOpenMPToLLVMConversionPatterns(converter, patterns);
|
||||||
|
|
|
@ -29,6 +29,10 @@ namespace arith {
|
||||||
class ArithmeticDialect;
|
class ArithmeticDialect;
|
||||||
} // namespace arith
|
} // namespace arith
|
||||||
|
|
||||||
|
namespace cf {
|
||||||
|
class ControlFlowDialect;
|
||||||
|
} // namespace cf
|
||||||
|
|
||||||
namespace complex {
|
namespace complex {
|
||||||
class ComplexDialect;
|
class ComplexDialect;
|
||||||
} // namespace complex
|
} // namespace complex
|
||||||
|
|
|
@ -1,8 +1,8 @@
|
||||||
add_mlir_conversion_library(MLIRSCFToStandard
|
add_mlir_conversion_library(MLIRSCFToControlFlow
|
||||||
SCFToStandard.cpp
|
SCFToControlFlow.cpp
|
||||||
|
|
||||||
ADDITIONAL_HEADER_DIRS
|
ADDITIONAL_HEADER_DIRS
|
||||||
${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/SCFToStandard
|
${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/SCFToControlFlow
|
||||||
|
|
||||||
DEPENDS
|
DEPENDS
|
||||||
MLIRConversionPassIncGen
|
MLIRConversionPassIncGen
|
||||||
|
@ -12,6 +12,7 @@ add_mlir_conversion_library(MLIRSCFToStandard
|
||||||
|
|
||||||
LINK_LIBS PUBLIC
|
LINK_LIBS PUBLIC
|
||||||
MLIRArithmetic
|
MLIRArithmetic
|
||||||
|
MLIRControlFlow
|
||||||
MLIRSCF
|
MLIRSCF
|
||||||
MLIRTransforms
|
MLIRTransforms
|
||||||
)
|
)
|
|
@ -1,4 +1,4 @@
|
||||||
//===- SCFToStandard.cpp - ControlFlow to CFG conversion ------------------===//
|
//===- SCFToControlFlow.cpp - SCF to CF conversion ------------------------===//
|
||||||
//
|
//
|
||||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
// See https://llvm.org/LICENSE.txt for license information.
|
// See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
@ -11,11 +11,11 @@
|
||||||
//
|
//
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
#include "mlir/Conversion/SCFToStandard/SCFToStandard.h"
|
#include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h"
|
||||||
#include "../PassDetail.h"
|
#include "../PassDetail.h"
|
||||||
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
|
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
|
||||||
|
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
|
||||||
#include "mlir/Dialect/SCF/SCF.h"
|
#include "mlir/Dialect/SCF/SCF.h"
|
||||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
|
||||||
#include "mlir/IR/BlockAndValueMapping.h"
|
#include "mlir/IR/BlockAndValueMapping.h"
|
||||||
#include "mlir/IR/Builders.h"
|
#include "mlir/IR/Builders.h"
|
||||||
#include "mlir/IR/BuiltinOps.h"
|
#include "mlir/IR/BuiltinOps.h"
|
||||||
|
@ -29,7 +29,8 @@ using namespace mlir::scf;
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
struct SCFToStandardPass : public SCFToStandardBase<SCFToStandardPass> {
|
struct SCFToControlFlowPass
|
||||||
|
: public SCFToControlFlowBase<SCFToControlFlowPass> {
|
||||||
void runOnOperation() override;
|
void runOnOperation() override;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -57,7 +58,7 @@ struct SCFToStandardPass : public SCFToStandardBase<SCFToStandardPass> {
|
||||||
// | <code before the ForOp> |
|
// | <code before the ForOp> |
|
||||||
// | <definitions of %init...> |
|
// | <definitions of %init...> |
|
||||||
// | <compute initial %iv value> |
|
// | <compute initial %iv value> |
|
||||||
// | br cond(%iv, %init...) |
|
// | cf.br cond(%iv, %init...) |
|
||||||
// +---------------------------------+
|
// +---------------------------------+
|
||||||
// |
|
// |
|
||||||
// -------| |
|
// -------| |
|
||||||
|
@ -65,7 +66,7 @@ struct SCFToStandardPass : public SCFToStandardBase<SCFToStandardPass> {
|
||||||
// | +--------------------------------+
|
// | +--------------------------------+
|
||||||
// | | cond(%iv, %init...): |
|
// | | cond(%iv, %init...): |
|
||||||
// | | <compare %iv to upper bound> |
|
// | | <compare %iv to upper bound> |
|
||||||
// | | cond_br %r, body, end |
|
// | | cf.cond_br %r, body, end |
|
||||||
// | +--------------------------------+
|
// | +--------------------------------+
|
||||||
// | | |
|
// | | |
|
||||||
// | | -------------|
|
// | | -------------|
|
||||||
|
@ -83,7 +84,7 @@ struct SCFToStandardPass : public SCFToStandardBase<SCFToStandardPass> {
|
||||||
// | | <body contents> | |
|
// | | <body contents> | |
|
||||||
// | | <operands of yield = %yields>| |
|
// | | <operands of yield = %yields>| |
|
||||||
// | | %new_iv =<add step to %iv> | |
|
// | | %new_iv =<add step to %iv> | |
|
||||||
// | | br cond(%new_iv, %yields) | |
|
// | | cf.br cond(%new_iv, %yields) | |
|
||||||
// | +--------------------------------+ |
|
// | +--------------------------------+ |
|
||||||
// | | |
|
// | | |
|
||||||
// |----------- |--------------------
|
// |----------- |--------------------
|
||||||
|
@ -125,7 +126,7 @@ struct ForLowering : public OpRewritePattern<ForOp> {
|
||||||
//
|
//
|
||||||
// +--------------------------------+
|
// +--------------------------------+
|
||||||
// | <code before the IfOp> |
|
// | <code before the IfOp> |
|
||||||
// | cond_br %cond, %then, %else |
|
// | cf.cond_br %cond, %then, %else |
|
||||||
// +--------------------------------+
|
// +--------------------------------+
|
||||||
// | |
|
// | |
|
||||||
// | --------------|
|
// | --------------|
|
||||||
|
@ -133,7 +134,7 @@ struct ForLowering : public OpRewritePattern<ForOp> {
|
||||||
// +--------------------------------+ |
|
// +--------------------------------+ |
|
||||||
// | then: | |
|
// | then: | |
|
||||||
// | <then contents> | |
|
// | <then contents> | |
|
||||||
// | br continue | |
|
// | cf.br continue | |
|
||||||
// +--------------------------------+ |
|
// +--------------------------------+ |
|
||||||
// | |
|
// | |
|
||||||
// |---------- |-------------
|
// |---------- |-------------
|
||||||
|
@ -141,7 +142,7 @@ struct ForLowering : public OpRewritePattern<ForOp> {
|
||||||
// | +--------------------------------+
|
// | +--------------------------------+
|
||||||
// | | else: |
|
// | | else: |
|
||||||
// | | <else contents> |
|
// | | <else contents> |
|
||||||
// | | br continue |
|
// | | cf.br continue |
|
||||||
// | +--------------------------------+
|
// | +--------------------------------+
|
||||||
// | |
|
// | |
|
||||||
// ------| |
|
// ------| |
|
||||||
|
@ -155,7 +156,7 @@ struct ForLowering : public OpRewritePattern<ForOp> {
|
||||||
//
|
//
|
||||||
// +--------------------------------+
|
// +--------------------------------+
|
||||||
// | <code before the IfOp> |
|
// | <code before the IfOp> |
|
||||||
// | cond_br %cond, %then, %else |
|
// | cf.cond_br %cond, %then, %else |
|
||||||
// +--------------------------------+
|
// +--------------------------------+
|
||||||
// | |
|
// | |
|
||||||
// | --------------|
|
// | --------------|
|
||||||
|
@ -163,7 +164,7 @@ struct ForLowering : public OpRewritePattern<ForOp> {
|
||||||
// +--------------------------------+ |
|
// +--------------------------------+ |
|
||||||
// | then: | |
|
// | then: | |
|
||||||
// | <then contents> | |
|
// | <then contents> | |
|
||||||
// | br dom(%args...) | |
|
// | cf.br dom(%args...) | |
|
||||||
// +--------------------------------+ |
|
// +--------------------------------+ |
|
||||||
// | |
|
// | |
|
||||||
// |---------- |-------------
|
// |---------- |-------------
|
||||||
|
@ -171,14 +172,14 @@ struct ForLowering : public OpRewritePattern<ForOp> {
|
||||||
// | +--------------------------------+
|
// | +--------------------------------+
|
||||||
// | | else: |
|
// | | else: |
|
||||||
// | | <else contents> |
|
// | | <else contents> |
|
||||||
// | | br dom(%args...) |
|
// | | cf.br dom(%args...) |
|
||||||
// | +--------------------------------+
|
// | +--------------------------------+
|
||||||
// | |
|
// | |
|
||||||
// ------| |
|
// ------| |
|
||||||
// v v
|
// v v
|
||||||
// +--------------------------------+
|
// +--------------------------------+
|
||||||
// | dom(%args...): |
|
// | dom(%args...): |
|
||||||
// | br continue |
|
// | cf.br continue |
|
||||||
// +--------------------------------+
|
// +--------------------------------+
|
||||||
// |
|
// |
|
||||||
// v
|
// v
|
||||||
|
@ -218,7 +219,7 @@ struct ParallelLowering : public OpRewritePattern<mlir::scf::ParallelOp> {
|
||||||
///
|
///
|
||||||
/// +---------------------------------+
|
/// +---------------------------------+
|
||||||
/// | <code before the WhileOp> |
|
/// | <code before the WhileOp> |
|
||||||
/// | br ^before(%operands...) |
|
/// | cf.br ^before(%operands...) |
|
||||||
/// +---------------------------------+
|
/// +---------------------------------+
|
||||||
/// |
|
/// |
|
||||||
/// -------| |
|
/// -------| |
|
||||||
|
@ -233,7 +234,7 @@ struct ParallelLowering : public OpRewritePattern<mlir::scf::ParallelOp> {
|
||||||
/// | +--------------------------------+
|
/// | +--------------------------------+
|
||||||
/// | | ^before-last:
|
/// | | ^before-last:
|
||||||
/// | | %cond = <compute condition> |
|
/// | | %cond = <compute condition> |
|
||||||
/// | | cond_br %cond, |
|
/// | | cf.cond_br %cond, |
|
||||||
/// | | ^after(%vals...), ^cont |
|
/// | | ^after(%vals...), ^cont |
|
||||||
/// | +--------------------------------+
|
/// | +--------------------------------+
|
||||||
/// | | |
|
/// | | |
|
||||||
|
@ -249,7 +250,7 @@ struct ParallelLowering : public OpRewritePattern<mlir::scf::ParallelOp> {
|
||||||
/// | +--------------------------------+ |
|
/// | +--------------------------------+ |
|
||||||
/// | | ^after-last: | |
|
/// | | ^after-last: | |
|
||||||
/// | | %yields... = <some payload> | |
|
/// | | %yields... = <some payload> | |
|
||||||
/// | | br ^before(%yields...) | |
|
/// | | cf.br ^before(%yields...) | |
|
||||||
/// | +--------------------------------+ |
|
/// | +--------------------------------+ |
|
||||||
/// | | |
|
/// | | |
|
||||||
/// |----------- |--------------------
|
/// |----------- |--------------------
|
||||||
|
@ -321,7 +322,7 @@ LogicalResult ForLowering::matchAndRewrite(ForOp forOp,
|
||||||
SmallVector<Value, 8> loopCarried;
|
SmallVector<Value, 8> loopCarried;
|
||||||
loopCarried.push_back(stepped);
|
loopCarried.push_back(stepped);
|
||||||
loopCarried.append(terminator->operand_begin(), terminator->operand_end());
|
loopCarried.append(terminator->operand_begin(), terminator->operand_end());
|
||||||
rewriter.create<BranchOp>(loc, conditionBlock, loopCarried);
|
rewriter.create<cf::BranchOp>(loc, conditionBlock, loopCarried);
|
||||||
rewriter.eraseOp(terminator);
|
rewriter.eraseOp(terminator);
|
||||||
|
|
||||||
// Compute loop bounds before branching to the condition.
|
// Compute loop bounds before branching to the condition.
|
||||||
|
@ -337,15 +338,16 @@ LogicalResult ForLowering::matchAndRewrite(ForOp forOp,
|
||||||
destOperands.push_back(lowerBound);
|
destOperands.push_back(lowerBound);
|
||||||
auto iterOperands = forOp.getIterOperands();
|
auto iterOperands = forOp.getIterOperands();
|
||||||
destOperands.append(iterOperands.begin(), iterOperands.end());
|
destOperands.append(iterOperands.begin(), iterOperands.end());
|
||||||
rewriter.create<BranchOp>(loc, conditionBlock, destOperands);
|
rewriter.create<cf::BranchOp>(loc, conditionBlock, destOperands);
|
||||||
|
|
||||||
// With the body block done, we can fill in the condition block.
|
// With the body block done, we can fill in the condition block.
|
||||||
rewriter.setInsertionPointToEnd(conditionBlock);
|
rewriter.setInsertionPointToEnd(conditionBlock);
|
||||||
auto comparison = rewriter.create<arith::CmpIOp>(
|
auto comparison = rewriter.create<arith::CmpIOp>(
|
||||||
loc, arith::CmpIPredicate::slt, iv, upperBound);
|
loc, arith::CmpIPredicate::slt, iv, upperBound);
|
||||||
|
|
||||||
rewriter.create<CondBranchOp>(loc, comparison, firstBodyBlock,
|
rewriter.create<cf::CondBranchOp>(loc, comparison, firstBodyBlock,
|
||||||
ArrayRef<Value>(), endBlock, ArrayRef<Value>());
|
ArrayRef<Value>(), endBlock,
|
||||||
|
ArrayRef<Value>());
|
||||||
// The result of the loop operation is the values of the condition block
|
// The result of the loop operation is the values of the condition block
|
||||||
// arguments except the induction variable on the last iteration.
|
// arguments except the induction variable on the last iteration.
|
||||||
rewriter.replaceOp(forOp, conditionBlock->getArguments().drop_front());
|
rewriter.replaceOp(forOp, conditionBlock->getArguments().drop_front());
|
||||||
|
@ -369,7 +371,7 @@ LogicalResult IfLowering::matchAndRewrite(IfOp ifOp,
|
||||||
continueBlock =
|
continueBlock =
|
||||||
rewriter.createBlock(remainingOpsBlock, ifOp.getResultTypes(),
|
rewriter.createBlock(remainingOpsBlock, ifOp.getResultTypes(),
|
||||||
SmallVector<Location>(ifOp.getNumResults(), loc));
|
SmallVector<Location>(ifOp.getNumResults(), loc));
|
||||||
rewriter.create<BranchOp>(loc, remainingOpsBlock);
|
rewriter.create<cf::BranchOp>(loc, remainingOpsBlock);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Move blocks from the "then" region to the region containing 'scf.if',
|
// Move blocks from the "then" region to the region containing 'scf.if',
|
||||||
|
@ -379,7 +381,7 @@ LogicalResult IfLowering::matchAndRewrite(IfOp ifOp,
|
||||||
Operation *thenTerminator = thenRegion.back().getTerminator();
|
Operation *thenTerminator = thenRegion.back().getTerminator();
|
||||||
ValueRange thenTerminatorOperands = thenTerminator->getOperands();
|
ValueRange thenTerminatorOperands = thenTerminator->getOperands();
|
||||||
rewriter.setInsertionPointToEnd(&thenRegion.back());
|
rewriter.setInsertionPointToEnd(&thenRegion.back());
|
||||||
rewriter.create<BranchOp>(loc, continueBlock, thenTerminatorOperands);
|
rewriter.create<cf::BranchOp>(loc, continueBlock, thenTerminatorOperands);
|
||||||
rewriter.eraseOp(thenTerminator);
|
rewriter.eraseOp(thenTerminator);
|
||||||
rewriter.inlineRegionBefore(thenRegion, continueBlock);
|
rewriter.inlineRegionBefore(thenRegion, continueBlock);
|
||||||
|
|
||||||
|
@ -393,15 +395,15 @@ LogicalResult IfLowering::matchAndRewrite(IfOp ifOp,
|
||||||
Operation *elseTerminator = elseRegion.back().getTerminator();
|
Operation *elseTerminator = elseRegion.back().getTerminator();
|
||||||
ValueRange elseTerminatorOperands = elseTerminator->getOperands();
|
ValueRange elseTerminatorOperands = elseTerminator->getOperands();
|
||||||
rewriter.setInsertionPointToEnd(&elseRegion.back());
|
rewriter.setInsertionPointToEnd(&elseRegion.back());
|
||||||
rewriter.create<BranchOp>(loc, continueBlock, elseTerminatorOperands);
|
rewriter.create<cf::BranchOp>(loc, continueBlock, elseTerminatorOperands);
|
||||||
rewriter.eraseOp(elseTerminator);
|
rewriter.eraseOp(elseTerminator);
|
||||||
rewriter.inlineRegionBefore(elseRegion, continueBlock);
|
rewriter.inlineRegionBefore(elseRegion, continueBlock);
|
||||||
}
|
}
|
||||||
|
|
||||||
rewriter.setInsertionPointToEnd(condBlock);
|
rewriter.setInsertionPointToEnd(condBlock);
|
||||||
rewriter.create<CondBranchOp>(loc, ifOp.getCondition(), thenBlock,
|
rewriter.create<cf::CondBranchOp>(loc, ifOp.getCondition(), thenBlock,
|
||||||
/*trueArgs=*/ArrayRef<Value>(), elseBlock,
|
/*trueArgs=*/ArrayRef<Value>(), elseBlock,
|
||||||
/*falseArgs=*/ArrayRef<Value>());
|
/*falseArgs=*/ArrayRef<Value>());
|
||||||
|
|
||||||
// Ok, we're done!
|
// Ok, we're done!
|
||||||
rewriter.replaceOp(ifOp, continueBlock->getArguments());
|
rewriter.replaceOp(ifOp, continueBlock->getArguments());
|
||||||
|
@ -419,13 +421,13 @@ ExecuteRegionLowering::matchAndRewrite(ExecuteRegionOp op,
|
||||||
|
|
||||||
auto &region = op.getRegion();
|
auto &region = op.getRegion();
|
||||||
rewriter.setInsertionPointToEnd(condBlock);
|
rewriter.setInsertionPointToEnd(condBlock);
|
||||||
rewriter.create<BranchOp>(loc, &region.front());
|
rewriter.create<cf::BranchOp>(loc, &region.front());
|
||||||
|
|
||||||
for (Block &block : region) {
|
for (Block &block : region) {
|
||||||
if (auto terminator = dyn_cast<scf::YieldOp>(block.getTerminator())) {
|
if (auto terminator = dyn_cast<scf::YieldOp>(block.getTerminator())) {
|
||||||
ValueRange terminatorOperands = terminator->getOperands();
|
ValueRange terminatorOperands = terminator->getOperands();
|
||||||
rewriter.setInsertionPointToEnd(&block);
|
rewriter.setInsertionPointToEnd(&block);
|
||||||
rewriter.create<BranchOp>(loc, remainingOpsBlock, terminatorOperands);
|
rewriter.create<cf::BranchOp>(loc, remainingOpsBlock, terminatorOperands);
|
||||||
rewriter.eraseOp(terminator);
|
rewriter.eraseOp(terminator);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -538,20 +540,21 @@ LogicalResult WhileLowering::matchAndRewrite(WhileOp whileOp,
|
||||||
|
|
||||||
// Branch to the "before" region.
|
// Branch to the "before" region.
|
||||||
rewriter.setInsertionPointToEnd(currentBlock);
|
rewriter.setInsertionPointToEnd(currentBlock);
|
||||||
rewriter.create<BranchOp>(loc, before, whileOp.getInits());
|
rewriter.create<cf::BranchOp>(loc, before, whileOp.getInits());
|
||||||
|
|
||||||
// Replace terminators with branches. Assuming bodies are SESE, which holds
|
// Replace terminators with branches. Assuming bodies are SESE, which holds
|
||||||
// given only the patterns from this file, we only need to look at the last
|
// given only the patterns from this file, we only need to look at the last
|
||||||
// block. This should be reconsidered if we allow break/continue in SCF.
|
// block. This should be reconsidered if we allow break/continue in SCF.
|
||||||
rewriter.setInsertionPointToEnd(beforeLast);
|
rewriter.setInsertionPointToEnd(beforeLast);
|
||||||
auto condOp = cast<ConditionOp>(beforeLast->getTerminator());
|
auto condOp = cast<ConditionOp>(beforeLast->getTerminator());
|
||||||
rewriter.replaceOpWithNewOp<CondBranchOp>(condOp, condOp.getCondition(),
|
rewriter.replaceOpWithNewOp<cf::CondBranchOp>(condOp, condOp.getCondition(),
|
||||||
after, condOp.getArgs(),
|
after, condOp.getArgs(),
|
||||||
continuation, ValueRange());
|
continuation, ValueRange());
|
||||||
|
|
||||||
rewriter.setInsertionPointToEnd(afterLast);
|
rewriter.setInsertionPointToEnd(afterLast);
|
||||||
auto yieldOp = cast<scf::YieldOp>(afterLast->getTerminator());
|
auto yieldOp = cast<scf::YieldOp>(afterLast->getTerminator());
|
||||||
rewriter.replaceOpWithNewOp<BranchOp>(yieldOp, before, yieldOp.getResults());
|
rewriter.replaceOpWithNewOp<cf::BranchOp>(yieldOp, before,
|
||||||
|
yieldOp.getResults());
|
||||||
|
|
||||||
// Replace the op with values "yielded" from the "before" region, which are
|
// Replace the op with values "yielded" from the "before" region, which are
|
||||||
// visible by dominance.
|
// visible by dominance.
|
||||||
|
@ -593,14 +596,14 @@ DoWhileLowering::matchAndRewrite(WhileOp whileOp,
|
||||||
|
|
||||||
// Branch to the "before" region.
|
// Branch to the "before" region.
|
||||||
rewriter.setInsertionPointToEnd(currentBlock);
|
rewriter.setInsertionPointToEnd(currentBlock);
|
||||||
rewriter.create<BranchOp>(whileOp.getLoc(), before, whileOp.getInits());
|
rewriter.create<cf::BranchOp>(whileOp.getLoc(), before, whileOp.getInits());
|
||||||
|
|
||||||
// Loop around the "before" region based on condition.
|
// Loop around the "before" region based on condition.
|
||||||
rewriter.setInsertionPointToEnd(beforeLast);
|
rewriter.setInsertionPointToEnd(beforeLast);
|
||||||
auto condOp = cast<ConditionOp>(beforeLast->getTerminator());
|
auto condOp = cast<ConditionOp>(beforeLast->getTerminator());
|
||||||
rewriter.replaceOpWithNewOp<CondBranchOp>(condOp, condOp.getCondition(),
|
rewriter.replaceOpWithNewOp<cf::CondBranchOp>(condOp, condOp.getCondition(),
|
||||||
before, condOp.getArgs(),
|
before, condOp.getArgs(),
|
||||||
continuation, ValueRange());
|
continuation, ValueRange());
|
||||||
|
|
||||||
// Replace the op with values "yielded" from the "before" region, which are
|
// Replace the op with values "yielded" from the "before" region, which are
|
||||||
// visible by dominance.
|
// visible by dominance.
|
||||||
|
@ -609,17 +612,18 @@ DoWhileLowering::matchAndRewrite(WhileOp whileOp,
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
void mlir::populateLoopToStdConversionPatterns(RewritePatternSet &patterns) {
|
void mlir::populateSCFToControlFlowConversionPatterns(
|
||||||
|
RewritePatternSet &patterns) {
|
||||||
patterns.add<ForLowering, IfLowering, ParallelLowering, WhileLowering,
|
patterns.add<ForLowering, IfLowering, ParallelLowering, WhileLowering,
|
||||||
ExecuteRegionLowering>(patterns.getContext());
|
ExecuteRegionLowering>(patterns.getContext());
|
||||||
patterns.add<DoWhileLowering>(patterns.getContext(), /*benefit=*/2);
|
patterns.add<DoWhileLowering>(patterns.getContext(), /*benefit=*/2);
|
||||||
}
|
}
|
||||||
|
|
||||||
void SCFToStandardPass::runOnOperation() {
|
void SCFToControlFlowPass::runOnOperation() {
|
||||||
RewritePatternSet patterns(&getContext());
|
RewritePatternSet patterns(&getContext());
|
||||||
populateLoopToStdConversionPatterns(patterns);
|
populateSCFToControlFlowConversionPatterns(patterns);
|
||||||
// Configure conversion to lower out scf.for, scf.if, scf.parallel and
|
|
||||||
// scf.while. Anything else is fine.
|
// Configure conversion to lower out SCF operations.
|
||||||
ConversionTarget target(getContext());
|
ConversionTarget target(getContext());
|
||||||
target.addIllegalOp<scf::ForOp, scf::IfOp, scf::ParallelOp, scf::WhileOp,
|
target.addIllegalOp<scf::ForOp, scf::IfOp, scf::ParallelOp, scf::WhileOp,
|
||||||
scf::ExecuteRegionOp>();
|
scf::ExecuteRegionOp>();
|
||||||
|
@ -629,6 +633,6 @@ void SCFToStandardPass::runOnOperation() {
|
||||||
signalPassFailure();
|
signalPassFailure();
|
||||||
}
|
}
|
||||||
|
|
||||||
std::unique_ptr<Pass> mlir::createLowerToCFGPass() {
|
std::unique_ptr<Pass> mlir::createConvertSCFToCFPass() {
|
||||||
return std::make_unique<SCFToStandardPass>();
|
return std::make_unique<SCFToControlFlowPass>();
|
||||||
}
|
}
|
|
@ -9,6 +9,7 @@
|
||||||
#include "mlir/Conversion/ShapeToStandard/ShapeToStandard.h"
|
#include "mlir/Conversion/ShapeToStandard/ShapeToStandard.h"
|
||||||
|
|
||||||
#include "../PassDetail.h"
|
#include "../PassDetail.h"
|
||||||
|
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
|
||||||
#include "mlir/Dialect/SCF/SCF.h"
|
#include "mlir/Dialect/SCF/SCF.h"
|
||||||
#include "mlir/Dialect/Shape/IR/Shape.h"
|
#include "mlir/Dialect/Shape/IR/Shape.h"
|
||||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||||
|
@ -29,7 +30,7 @@ public:
|
||||||
using OpRewritePattern::OpRewritePattern;
|
using OpRewritePattern::OpRewritePattern;
|
||||||
LogicalResult matchAndRewrite(shape::CstrRequireOp op,
|
LogicalResult matchAndRewrite(shape::CstrRequireOp op,
|
||||||
PatternRewriter &rewriter) const override {
|
PatternRewriter &rewriter) const override {
|
||||||
rewriter.create<AssertOp>(op.getLoc(), op.getPred(), op.getMsgAttr());
|
rewriter.create<cf::AssertOp>(op.getLoc(), op.getPred(), op.getMsgAttr());
|
||||||
rewriter.replaceOpWithNewOp<shape::ConstWitnessOp>(op, true);
|
rewriter.replaceOpWithNewOp<shape::ConstWitnessOp>(op, true);
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
|
@ -14,6 +14,7 @@ add_mlir_conversion_library(MLIRStandardToLLVM
|
||||||
LINK_LIBS PUBLIC
|
LINK_LIBS PUBLIC
|
||||||
MLIRAnalysis
|
MLIRAnalysis
|
||||||
MLIRArithmeticToLLVM
|
MLIRArithmeticToLLVM
|
||||||
|
MLIRControlFlowToLLVM
|
||||||
MLIRDataLayoutInterfaces
|
MLIRDataLayoutInterfaces
|
||||||
MLIRLLVMCommonConversion
|
MLIRLLVMCommonConversion
|
||||||
MLIRLLVMIR
|
MLIRLLVMIR
|
||||||
|
|
|
@ -14,6 +14,7 @@
|
||||||
#include "../PassDetail.h"
|
#include "../PassDetail.h"
|
||||||
#include "mlir/Analysis/DataLayoutAnalysis.h"
|
#include "mlir/Analysis/DataLayoutAnalysis.h"
|
||||||
#include "mlir/Conversion/ArithmeticToLLVM/ArithmeticToLLVM.h"
|
#include "mlir/Conversion/ArithmeticToLLVM/ArithmeticToLLVM.h"
|
||||||
|
#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
|
||||||
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
|
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
|
||||||
#include "mlir/Conversion/LLVMCommon/Pattern.h"
|
#include "mlir/Conversion/LLVMCommon/Pattern.h"
|
||||||
#include "mlir/Conversion/LLVMCommon/VectorPattern.h"
|
#include "mlir/Conversion/LLVMCommon/VectorPattern.h"
|
||||||
|
@ -387,48 +388,6 @@ struct BarePtrFuncOpConversion : public FuncOpConversionBase {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
/// Lower `std.assert`. The default lowering calls the `abort` function if the
|
|
||||||
/// assertion is violated and has no effect otherwise. The failure message is
|
|
||||||
/// ignored by the default lowering but should be propagated by any custom
|
|
||||||
/// lowering.
|
|
||||||
struct AssertOpLowering : public ConvertOpToLLVMPattern<AssertOp> {
|
|
||||||
using ConvertOpToLLVMPattern<AssertOp>::ConvertOpToLLVMPattern;
|
|
||||||
|
|
||||||
LogicalResult
|
|
||||||
matchAndRewrite(AssertOp op, OpAdaptor adaptor,
|
|
||||||
ConversionPatternRewriter &rewriter) const override {
|
|
||||||
auto loc = op.getLoc();
|
|
||||||
|
|
||||||
// Insert the `abort` declaration if necessary.
|
|
||||||
auto module = op->getParentOfType<ModuleOp>();
|
|
||||||
auto abortFunc = module.lookupSymbol<LLVM::LLVMFuncOp>("abort");
|
|
||||||
if (!abortFunc) {
|
|
||||||
OpBuilder::InsertionGuard guard(rewriter);
|
|
||||||
rewriter.setInsertionPointToStart(module.getBody());
|
|
||||||
auto abortFuncTy = LLVM::LLVMFunctionType::get(getVoidType(), {});
|
|
||||||
abortFunc = rewriter.create<LLVM::LLVMFuncOp>(rewriter.getUnknownLoc(),
|
|
||||||
"abort", abortFuncTy);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Split block at `assert` operation.
|
|
||||||
Block *opBlock = rewriter.getInsertionBlock();
|
|
||||||
auto opPosition = rewriter.getInsertionPoint();
|
|
||||||
Block *continuationBlock = rewriter.splitBlock(opBlock, opPosition);
|
|
||||||
|
|
||||||
// Generate IR to call `abort`.
|
|
||||||
Block *failureBlock = rewriter.createBlock(opBlock->getParent());
|
|
||||||
rewriter.create<LLVM::CallOp>(loc, abortFunc, llvm::None);
|
|
||||||
rewriter.create<LLVM::UnreachableOp>(loc);
|
|
||||||
|
|
||||||
// Generate assertion test.
|
|
||||||
rewriter.setInsertionPointToEnd(opBlock);
|
|
||||||
rewriter.replaceOpWithNewOp<LLVM::CondBrOp>(
|
|
||||||
op, adaptor.getArg(), continuationBlock, failureBlock);
|
|
||||||
|
|
||||||
return success();
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
struct ConstantOpLowering : public ConvertOpToLLVMPattern<ConstantOp> {
|
struct ConstantOpLowering : public ConvertOpToLLVMPattern<ConstantOp> {
|
||||||
using ConvertOpToLLVMPattern<ConstantOp>::ConvertOpToLLVMPattern;
|
using ConvertOpToLLVMPattern<ConstantOp>::ConvertOpToLLVMPattern;
|
||||||
|
|
||||||
|
@ -550,22 +509,6 @@ struct UnrealizedConversionCastOpLowering
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
// Base class for LLVM IR lowering terminator operations with successors.
|
|
||||||
template <typename SourceOp, typename TargetOp>
|
|
||||||
struct OneToOneLLVMTerminatorLowering
|
|
||||||
: public ConvertOpToLLVMPattern<SourceOp> {
|
|
||||||
using ConvertOpToLLVMPattern<SourceOp>::ConvertOpToLLVMPattern;
|
|
||||||
using Super = OneToOneLLVMTerminatorLowering<SourceOp, TargetOp>;
|
|
||||||
|
|
||||||
LogicalResult
|
|
||||||
matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor,
|
|
||||||
ConversionPatternRewriter &rewriter) const override {
|
|
||||||
rewriter.replaceOpWithNewOp<TargetOp>(op, adaptor.getOperands(),
|
|
||||||
op->getSuccessors(), op->getAttrs());
|
|
||||||
return success();
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
// Special lowering pattern for `ReturnOps`. Unlike all other operations,
|
// Special lowering pattern for `ReturnOps`. Unlike all other operations,
|
||||||
// `ReturnOp` interacts with the function signature and must have as many
|
// `ReturnOp` interacts with the function signature and must have as many
|
||||||
// operands as the function has return values. Because in LLVM IR, functions
|
// operands as the function has return values. Because in LLVM IR, functions
|
||||||
|
@ -633,21 +576,6 @@ struct ReturnOpLowering : public ConvertOpToLLVMPattern<ReturnOp> {
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
// FIXME: this should be tablegen'ed as well.
|
|
||||||
struct BranchOpLowering
|
|
||||||
: public OneToOneLLVMTerminatorLowering<BranchOp, LLVM::BrOp> {
|
|
||||||
using Super::Super;
|
|
||||||
};
|
|
||||||
struct CondBranchOpLowering
|
|
||||||
: public OneToOneLLVMTerminatorLowering<CondBranchOp, LLVM::CondBrOp> {
|
|
||||||
using Super::Super;
|
|
||||||
};
|
|
||||||
struct SwitchOpLowering
|
|
||||||
: public OneToOneLLVMTerminatorLowering<SwitchOp, LLVM::SwitchOp> {
|
|
||||||
using Super::Super;
|
|
||||||
};
|
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
void mlir::populateStdToLLVMFuncOpConversionPattern(
|
void mlir::populateStdToLLVMFuncOpConversionPattern(
|
||||||
|
@ -663,14 +591,10 @@ void mlir::populateStdToLLVMConversionPatterns(LLVMTypeConverter &converter,
|
||||||
populateStdToLLVMFuncOpConversionPattern(converter, patterns);
|
populateStdToLLVMFuncOpConversionPattern(converter, patterns);
|
||||||
// clang-format off
|
// clang-format off
|
||||||
patterns.add<
|
patterns.add<
|
||||||
AssertOpLowering,
|
|
||||||
BranchOpLowering,
|
|
||||||
CallIndirectOpLowering,
|
CallIndirectOpLowering,
|
||||||
CallOpLowering,
|
CallOpLowering,
|
||||||
CondBranchOpLowering,
|
|
||||||
ConstantOpLowering,
|
ConstantOpLowering,
|
||||||
ReturnOpLowering,
|
ReturnOpLowering>(converter);
|
||||||
SwitchOpLowering>(converter);
|
|
||||||
// clang-format on
|
// clang-format on
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -721,6 +645,7 @@ struct LLVMLoweringPass : public ConvertStandardToLLVMBase<LLVMLoweringPass> {
|
||||||
RewritePatternSet patterns(&getContext());
|
RewritePatternSet patterns(&getContext());
|
||||||
populateStdToLLVMConversionPatterns(typeConverter, patterns);
|
populateStdToLLVMConversionPatterns(typeConverter, patterns);
|
||||||
arith::populateArithmeticToLLVMConversionPatterns(typeConverter, patterns);
|
arith::populateArithmeticToLLVMConversionPatterns(typeConverter, patterns);
|
||||||
|
cf::populateControlFlowToLLVMConversionPatterns(typeConverter, patterns);
|
||||||
|
|
||||||
LLVMConversionTarget target(getContext());
|
LLVMConversionTarget target(getContext());
|
||||||
if (failed(applyPartialConversion(m, target, std::move(patterns))))
|
if (failed(applyPartialConversion(m, target, std::move(patterns))))
|
||||||
|
|
|
@ -11,6 +11,7 @@ add_mlir_conversion_library(MLIRStandardToSPIRV
|
||||||
|
|
||||||
LINK_LIBS PUBLIC
|
LINK_LIBS PUBLIC
|
||||||
MLIRArithmeticToSPIRV
|
MLIRArithmeticToSPIRV
|
||||||
|
MLIRControlFlowToSPIRV
|
||||||
MLIRIR
|
MLIRIR
|
||||||
MLIRMathToSPIRV
|
MLIRMathToSPIRV
|
||||||
MLIRMemRef
|
MLIRMemRef
|
||||||
|
|
|
@ -46,24 +46,6 @@ public:
|
||||||
ConversionPatternRewriter &rewriter) const override;
|
ConversionPatternRewriter &rewriter) const override;
|
||||||
};
|
};
|
||||||
|
|
||||||
/// Converts std.br to spv.Branch.
|
|
||||||
struct BranchOpPattern final : public OpConversionPattern<BranchOp> {
|
|
||||||
using OpConversionPattern<BranchOp>::OpConversionPattern;
|
|
||||||
|
|
||||||
LogicalResult
|
|
||||||
matchAndRewrite(BranchOp op, OpAdaptor adaptor,
|
|
||||||
ConversionPatternRewriter &rewriter) const override;
|
|
||||||
};
|
|
||||||
|
|
||||||
/// Converts std.cond_br to spv.BranchConditional.
|
|
||||||
struct CondBranchOpPattern final : public OpConversionPattern<CondBranchOp> {
|
|
||||||
using OpConversionPattern<CondBranchOp>::OpConversionPattern;
|
|
||||||
|
|
||||||
LogicalResult
|
|
||||||
matchAndRewrite(CondBranchOp op, OpAdaptor adaptor,
|
|
||||||
ConversionPatternRewriter &rewriter) const override;
|
|
||||||
};
|
|
||||||
|
|
||||||
/// Converts tensor.extract into loading using access chains from SPIR-V local
|
/// Converts tensor.extract into loading using access chains from SPIR-V local
|
||||||
/// variables.
|
/// variables.
|
||||||
class TensorExtractPattern final
|
class TensorExtractPattern final
|
||||||
|
@ -146,31 +128,6 @@ ReturnOpPattern::matchAndRewrite(ReturnOp returnOp, OpAdaptor adaptor,
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
|
||||||
// BranchOpPattern
|
|
||||||
//===----------------------------------------------------------------------===//
|
|
||||||
|
|
||||||
LogicalResult
|
|
||||||
BranchOpPattern::matchAndRewrite(BranchOp op, OpAdaptor adaptor,
|
|
||||||
ConversionPatternRewriter &rewriter) const {
|
|
||||||
rewriter.replaceOpWithNewOp<spirv::BranchOp>(op, op.getDest(),
|
|
||||||
adaptor.getDestOperands());
|
|
||||||
return success();
|
|
||||||
}
|
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
|
||||||
// CondBranchOpPattern
|
|
||||||
//===----------------------------------------------------------------------===//
|
|
||||||
|
|
||||||
LogicalResult CondBranchOpPattern::matchAndRewrite(
|
|
||||||
CondBranchOp op, OpAdaptor adaptor,
|
|
||||||
ConversionPatternRewriter &rewriter) const {
|
|
||||||
rewriter.replaceOpWithNewOp<spirv::BranchConditionalOp>(
|
|
||||||
op, op.getCondition(), op.getTrueDest(), adaptor.getTrueDestOperands(),
|
|
||||||
op.getFalseDest(), adaptor.getFalseDestOperands());
|
|
||||||
return success();
|
|
||||||
}
|
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// Pattern population
|
// Pattern population
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
@ -189,8 +146,7 @@ void populateStandardToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
|
||||||
spirv::ElementwiseOpPattern<arith::MinSIOp, spirv::GLSLSMinOp>,
|
spirv::ElementwiseOpPattern<arith::MinSIOp, spirv::GLSLSMinOp>,
|
||||||
spirv::ElementwiseOpPattern<arith::MinUIOp, spirv::GLSLUMinOp>,
|
spirv::ElementwiseOpPattern<arith::MinUIOp, spirv::GLSLUMinOp>,
|
||||||
|
|
||||||
ReturnOpPattern, BranchOpPattern, CondBranchOpPattern>(typeConverter,
|
ReturnOpPattern>(typeConverter, context);
|
||||||
context);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void populateTensorToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
|
void populateTensorToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
|
||||||
|
|
|
@ -13,6 +13,7 @@
|
||||||
#include "mlir/Conversion/StandardToSPIRV/StandardToSPIRVPass.h"
|
#include "mlir/Conversion/StandardToSPIRV/StandardToSPIRVPass.h"
|
||||||
#include "../PassDetail.h"
|
#include "../PassDetail.h"
|
||||||
#include "mlir/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.h"
|
#include "mlir/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.h"
|
||||||
|
#include "mlir/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRV.h"
|
||||||
#include "mlir/Conversion/MathToSPIRV/MathToSPIRV.h"
|
#include "mlir/Conversion/MathToSPIRV/MathToSPIRV.h"
|
||||||
#include "mlir/Conversion/StandardToSPIRV/StandardToSPIRV.h"
|
#include "mlir/Conversion/StandardToSPIRV/StandardToSPIRV.h"
|
||||||
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
|
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
|
||||||
|
@ -40,9 +41,11 @@ void ConvertStandardToSPIRVPass::runOnOperation() {
|
||||||
options.emulateNon32BitScalarTypes = this->emulateNon32BitScalarTypes;
|
options.emulateNon32BitScalarTypes = this->emulateNon32BitScalarTypes;
|
||||||
SPIRVTypeConverter typeConverter(targetAttr, options);
|
SPIRVTypeConverter typeConverter(targetAttr, options);
|
||||||
|
|
||||||
// TODO ArithmeticToSPIRV cannot be applied separately to StandardToSPIRV
|
// TODO ArithmeticToSPIRV/ControlFlowToSPIRV cannot be applied separately to
|
||||||
|
// StandardToSPIRV
|
||||||
RewritePatternSet patterns(context);
|
RewritePatternSet patterns(context);
|
||||||
arith::populateArithmeticToSPIRVPatterns(typeConverter, patterns);
|
arith::populateArithmeticToSPIRVPatterns(typeConverter, patterns);
|
||||||
|
cf::populateControlFlowToSPIRVPatterns(typeConverter, patterns);
|
||||||
populateMathToSPIRVPatterns(typeConverter, patterns);
|
populateMathToSPIRVPatterns(typeConverter, patterns);
|
||||||
populateStandardToSPIRVPatterns(typeConverter, patterns);
|
populateStandardToSPIRVPatterns(typeConverter, patterns);
|
||||||
populateTensorToSPIRVPatterns(typeConverter, /*byteCountThreshold=*/64,
|
populateTensorToSPIRVPatterns(typeConverter, /*byteCountThreshold=*/64,
|
||||||
|
|
|
@ -15,6 +15,7 @@
|
||||||
#include "mlir/Analysis/Liveness.h"
|
#include "mlir/Analysis/Liveness.h"
|
||||||
#include "mlir/Dialect/Async/IR/Async.h"
|
#include "mlir/Dialect/Async/IR/Async.h"
|
||||||
#include "mlir/Dialect/Async/Passes.h"
|
#include "mlir/Dialect/Async/Passes.h"
|
||||||
|
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
|
||||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||||
#include "mlir/IR/ImplicitLocOpBuilder.h"
|
#include "mlir/IR/ImplicitLocOpBuilder.h"
|
||||||
#include "mlir/IR/PatternMatch.h"
|
#include "mlir/IR/PatternMatch.h"
|
||||||
|
@ -169,11 +170,11 @@ private:
|
||||||
///
|
///
|
||||||
/// ^entry:
|
/// ^entry:
|
||||||
/// %token = async.runtime.create : !async.token
|
/// %token = async.runtime.create : !async.token
|
||||||
/// cond_br %cond, ^bb1, ^bb2
|
/// cf.cond_br %cond, ^bb1, ^bb2
|
||||||
/// ^bb1:
|
/// ^bb1:
|
||||||
/// async.runtime.await %token
|
/// async.runtime.await %token
|
||||||
/// async.runtime.drop_ref %token
|
/// async.runtime.drop_ref %token
|
||||||
/// br ^bb2
|
/// cf.br ^bb2
|
||||||
/// ^bb2:
|
/// ^bb2:
|
||||||
/// return
|
/// return
|
||||||
///
|
///
|
||||||
|
@ -185,14 +186,14 @@ private:
|
||||||
///
|
///
|
||||||
/// ^entry:
|
/// ^entry:
|
||||||
/// %token = async.runtime.create : !async.token
|
/// %token = async.runtime.create : !async.token
|
||||||
/// cond_br %cond, ^bb1, ^reference_counting
|
/// cf.cond_br %cond, ^bb1, ^reference_counting
|
||||||
/// ^bb1:
|
/// ^bb1:
|
||||||
/// async.runtime.await %token
|
/// async.runtime.await %token
|
||||||
/// async.runtime.drop_ref %token
|
/// async.runtime.drop_ref %token
|
||||||
/// br ^bb2
|
/// cf.br ^bb2
|
||||||
/// ^reference_counting:
|
/// ^reference_counting:
|
||||||
/// async.runtime.drop_ref %token
|
/// async.runtime.drop_ref %token
|
||||||
/// br ^bb2
|
/// cf.br ^bb2
|
||||||
/// ^bb2:
|
/// ^bb2:
|
||||||
/// return
|
/// return
|
||||||
///
|
///
|
||||||
|
@ -208,7 +209,7 @@ private:
|
||||||
/// async.coro.suspend %ret, ^suspend, ^resume, ^cleanup
|
/// async.coro.suspend %ret, ^suspend, ^resume, ^cleanup
|
||||||
/// ^resume:
|
/// ^resume:
|
||||||
/// %0 = async.runtime.load %value
|
/// %0 = async.runtime.load %value
|
||||||
/// br ^cleanup
|
/// cf.br ^cleanup
|
||||||
/// ^cleanup:
|
/// ^cleanup:
|
||||||
/// ...
|
/// ...
|
||||||
/// ^suspend:
|
/// ^suspend:
|
||||||
|
@ -406,7 +407,7 @@ AsyncRuntimeRefCountingPass::addDropRefInDivergentLivenessSuccessor(
|
||||||
refCountingBlock = &successor->getParent()->emplaceBlock();
|
refCountingBlock = &successor->getParent()->emplaceBlock();
|
||||||
refCountingBlock->moveBefore(successor);
|
refCountingBlock->moveBefore(successor);
|
||||||
OpBuilder builder = OpBuilder::atBlockEnd(refCountingBlock);
|
OpBuilder builder = OpBuilder::atBlockEnd(refCountingBlock);
|
||||||
builder.create<BranchOp>(value.getLoc(), successor);
|
builder.create<cf::BranchOp>(value.getLoc(), successor);
|
||||||
}
|
}
|
||||||
|
|
||||||
OpBuilder builder = OpBuilder::atBlockBegin(refCountingBlock);
|
OpBuilder builder = OpBuilder::atBlockBegin(refCountingBlock);
|
||||||
|
|
|
@ -12,10 +12,11 @@
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
#include "PassDetail.h"
|
#include "PassDetail.h"
|
||||||
#include "mlir/Conversion/SCFToStandard/SCFToStandard.h"
|
#include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h"
|
||||||
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
|
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
|
||||||
#include "mlir/Dialect/Async/IR/Async.h"
|
#include "mlir/Dialect/Async/IR/Async.h"
|
||||||
#include "mlir/Dialect/Async/Passes.h"
|
#include "mlir/Dialect/Async/Passes.h"
|
||||||
|
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
|
||||||
#include "mlir/Dialect/SCF/SCF.h"
|
#include "mlir/Dialect/SCF/SCF.h"
|
||||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||||
#include "mlir/IR/BlockAndValueMapping.h"
|
#include "mlir/IR/BlockAndValueMapping.h"
|
||||||
|
@ -105,18 +106,18 @@ struct CoroMachinery {
|
||||||
/// %value = <async value> : !async.value<T> // create async value
|
/// %value = <async value> : !async.value<T> // create async value
|
||||||
/// %id = async.coro.id // create a coroutine id
|
/// %id = async.coro.id // create a coroutine id
|
||||||
/// %hdl = async.coro.begin %id // create a coroutine handle
|
/// %hdl = async.coro.begin %id // create a coroutine handle
|
||||||
/// br ^preexisting_entry_block
|
/// cf.br ^preexisting_entry_block
|
||||||
///
|
///
|
||||||
/// /* preexisting blocks modified to branch to the cleanup block */
|
/// /* preexisting blocks modified to branch to the cleanup block */
|
||||||
///
|
///
|
||||||
/// ^set_error: // this block created lazily only if needed (see code below)
|
/// ^set_error: // this block created lazily only if needed (see code below)
|
||||||
/// async.runtime.set_error %token : !async.token
|
/// async.runtime.set_error %token : !async.token
|
||||||
/// async.runtime.set_error %value : !async.value<T>
|
/// async.runtime.set_error %value : !async.value<T>
|
||||||
/// br ^cleanup
|
/// cf.br ^cleanup
|
||||||
///
|
///
|
||||||
/// ^cleanup:
|
/// ^cleanup:
|
||||||
/// async.coro.free %hdl // delete the coroutine state
|
/// async.coro.free %hdl // delete the coroutine state
|
||||||
/// br ^suspend
|
/// cf.br ^suspend
|
||||||
///
|
///
|
||||||
/// ^suspend:
|
/// ^suspend:
|
||||||
/// async.coro.end %hdl // marks the end of a coroutine
|
/// async.coro.end %hdl // marks the end of a coroutine
|
||||||
|
@ -147,7 +148,7 @@ static CoroMachinery setupCoroMachinery(FuncOp func) {
|
||||||
auto coroIdOp = builder.create<CoroIdOp>(CoroIdType::get(ctx));
|
auto coroIdOp = builder.create<CoroIdOp>(CoroIdType::get(ctx));
|
||||||
auto coroHdlOp =
|
auto coroHdlOp =
|
||||||
builder.create<CoroBeginOp>(CoroHandleType::get(ctx), coroIdOp.id());
|
builder.create<CoroBeginOp>(CoroHandleType::get(ctx), coroIdOp.id());
|
||||||
builder.create<BranchOp>(originalEntryBlock);
|
builder.create<cf::BranchOp>(originalEntryBlock);
|
||||||
|
|
||||||
Block *cleanupBlock = func.addBlock();
|
Block *cleanupBlock = func.addBlock();
|
||||||
Block *suspendBlock = func.addBlock();
|
Block *suspendBlock = func.addBlock();
|
||||||
|
@ -159,7 +160,7 @@ static CoroMachinery setupCoroMachinery(FuncOp func) {
|
||||||
builder.create<CoroFreeOp>(coroIdOp.id(), coroHdlOp.handle());
|
builder.create<CoroFreeOp>(coroIdOp.id(), coroHdlOp.handle());
|
||||||
|
|
||||||
// Branch into the suspend block.
|
// Branch into the suspend block.
|
||||||
builder.create<BranchOp>(suspendBlock);
|
builder.create<cf::BranchOp>(suspendBlock);
|
||||||
|
|
||||||
// ------------------------------------------------------------------------ //
|
// ------------------------------------------------------------------------ //
|
||||||
// Coroutine suspend block: mark the end of a coroutine and return allocated
|
// Coroutine suspend block: mark the end of a coroutine and return allocated
|
||||||
|
@ -186,7 +187,7 @@ static CoroMachinery setupCoroMachinery(FuncOp func) {
|
||||||
Operation *terminator = block.getTerminator();
|
Operation *terminator = block.getTerminator();
|
||||||
if (auto yield = dyn_cast<YieldOp>(terminator)) {
|
if (auto yield = dyn_cast<YieldOp>(terminator)) {
|
||||||
builder.setInsertionPointToEnd(&block);
|
builder.setInsertionPointToEnd(&block);
|
||||||
builder.create<BranchOp>(cleanupBlock);
|
builder.create<cf::BranchOp>(cleanupBlock);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -227,7 +228,7 @@ static Block *setupSetErrorBlock(CoroMachinery &coro) {
|
||||||
builder.create<RuntimeSetErrorOp>(retValue);
|
builder.create<RuntimeSetErrorOp>(retValue);
|
||||||
|
|
||||||
// Branch into the cleanup block.
|
// Branch into the cleanup block.
|
||||||
builder.create<BranchOp>(coro.cleanup);
|
builder.create<cf::BranchOp>(coro.cleanup);
|
||||||
|
|
||||||
return coro.setError;
|
return coro.setError;
|
||||||
}
|
}
|
||||||
|
@ -305,7 +306,7 @@ outlineExecuteOp(SymbolTable &symbolTable, ExecuteOp execute) {
|
||||||
// Async resume operation (execution will be resumed in a thread managed by
|
// Async resume operation (execution will be resumed in a thread managed by
|
||||||
// the async runtime).
|
// the async runtime).
|
||||||
{
|
{
|
||||||
BranchOp branch = cast<BranchOp>(coro.entry->getTerminator());
|
cf::BranchOp branch = cast<cf::BranchOp>(coro.entry->getTerminator());
|
||||||
builder.setInsertionPointToEnd(coro.entry);
|
builder.setInsertionPointToEnd(coro.entry);
|
||||||
|
|
||||||
// Save the coroutine state: async.coro.save
|
// Save the coroutine state: async.coro.save
|
||||||
|
@ -419,8 +420,8 @@ public:
|
||||||
isError, builder.create<arith::ConstantOp>(
|
isError, builder.create<arith::ConstantOp>(
|
||||||
loc, i1, builder.getIntegerAttr(i1, 1)));
|
loc, i1, builder.getIntegerAttr(i1, 1)));
|
||||||
|
|
||||||
builder.create<AssertOp>(notError,
|
builder.create<cf::AssertOp>(notError,
|
||||||
"Awaited async operand is in error state");
|
"Awaited async operand is in error state");
|
||||||
}
|
}
|
||||||
|
|
||||||
// Inside the coroutine we convert await operation into coroutine suspension
|
// Inside the coroutine we convert await operation into coroutine suspension
|
||||||
|
@ -452,11 +453,11 @@ public:
|
||||||
// Check if the awaited value is in the error state.
|
// Check if the awaited value is in the error state.
|
||||||
builder.setInsertionPointToStart(resume);
|
builder.setInsertionPointToStart(resume);
|
||||||
auto isError = builder.create<RuntimeIsErrorOp>(loc, i1, operand);
|
auto isError = builder.create<RuntimeIsErrorOp>(loc, i1, operand);
|
||||||
builder.create<CondBranchOp>(isError,
|
builder.create<cf::CondBranchOp>(isError,
|
||||||
/*trueDest=*/setupSetErrorBlock(coro),
|
/*trueDest=*/setupSetErrorBlock(coro),
|
||||||
/*trueArgs=*/ArrayRef<Value>(),
|
/*trueArgs=*/ArrayRef<Value>(),
|
||||||
/*falseDest=*/continuation,
|
/*falseDest=*/continuation,
|
||||||
/*falseArgs=*/ArrayRef<Value>());
|
/*falseArgs=*/ArrayRef<Value>());
|
||||||
|
|
||||||
// Make sure that replacement value will be constructed in the
|
// Make sure that replacement value will be constructed in the
|
||||||
// continuation block.
|
// continuation block.
|
||||||
|
@ -560,18 +561,18 @@ private:
|
||||||
};
|
};
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// Convert std.assert operation to cond_br into `set_error` block.
|
// Convert std.assert operation to cf.cond_br into `set_error` block.
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
class AssertOpLowering : public OpConversionPattern<AssertOp> {
|
class AssertOpLowering : public OpConversionPattern<cf::AssertOp> {
|
||||||
public:
|
public:
|
||||||
AssertOpLowering(MLIRContext *ctx,
|
AssertOpLowering(MLIRContext *ctx,
|
||||||
llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions)
|
llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions)
|
||||||
: OpConversionPattern<AssertOp>(ctx),
|
: OpConversionPattern<cf::AssertOp>(ctx),
|
||||||
outlinedFunctions(outlinedFunctions) {}
|
outlinedFunctions(outlinedFunctions) {}
|
||||||
|
|
||||||
LogicalResult
|
LogicalResult
|
||||||
matchAndRewrite(AssertOp op, OpAdaptor adaptor,
|
matchAndRewrite(cf::AssertOp op, OpAdaptor adaptor,
|
||||||
ConversionPatternRewriter &rewriter) const override {
|
ConversionPatternRewriter &rewriter) const override {
|
||||||
// Check if assert operation is inside the async coroutine function.
|
// Check if assert operation is inside the async coroutine function.
|
||||||
auto func = op->template getParentOfType<FuncOp>();
|
auto func = op->template getParentOfType<FuncOp>();
|
||||||
|
@ -585,11 +586,11 @@ public:
|
||||||
|
|
||||||
Block *cont = rewriter.splitBlock(op->getBlock(), Block::iterator(op));
|
Block *cont = rewriter.splitBlock(op->getBlock(), Block::iterator(op));
|
||||||
rewriter.setInsertionPointToEnd(cont->getPrevNode());
|
rewriter.setInsertionPointToEnd(cont->getPrevNode());
|
||||||
rewriter.create<CondBranchOp>(loc, adaptor.getArg(),
|
rewriter.create<cf::CondBranchOp>(loc, adaptor.getArg(),
|
||||||
/*trueDest=*/cont,
|
/*trueDest=*/cont,
|
||||||
/*trueArgs=*/ArrayRef<Value>(),
|
/*trueArgs=*/ArrayRef<Value>(),
|
||||||
/*falseDest=*/setupSetErrorBlock(coro),
|
/*falseDest=*/setupSetErrorBlock(coro),
|
||||||
/*falseArgs=*/ArrayRef<Value>());
|
/*falseArgs=*/ArrayRef<Value>());
|
||||||
rewriter.eraseOp(op);
|
rewriter.eraseOp(op);
|
||||||
|
|
||||||
return success();
|
return success();
|
||||||
|
@ -765,7 +766,7 @@ void AsyncToAsyncRuntimePass::runOnOperation() {
|
||||||
// and we have to make sure that structured control flow operations with async
|
// and we have to make sure that structured control flow operations with async
|
||||||
// operations in nested regions will be converted to branch-based control flow
|
// operations in nested regions will be converted to branch-based control flow
|
||||||
// before we add the coroutine basic blocks.
|
// before we add the coroutine basic blocks.
|
||||||
populateLoopToStdConversionPatterns(asyncPatterns);
|
populateSCFToControlFlowConversionPatterns(asyncPatterns);
|
||||||
|
|
||||||
// Async lowering does not use type converter because it must preserve all
|
// Async lowering does not use type converter because it must preserve all
|
||||||
// types for async.runtime operations.
|
// types for async.runtime operations.
|
||||||
|
@ -792,14 +793,15 @@ void AsyncToAsyncRuntimePass::runOnOperation() {
|
||||||
});
|
});
|
||||||
return !walkResult.wasInterrupted();
|
return !walkResult.wasInterrupted();
|
||||||
});
|
});
|
||||||
runtimeTarget.addLegalOp<AssertOp, arith::XOrIOp, arith::ConstantOp,
|
runtimeTarget.addLegalOp<cf::AssertOp, arith::XOrIOp, arith::ConstantOp,
|
||||||
ConstantOp, BranchOp, CondBranchOp>();
|
ConstantOp, cf::BranchOp, cf::CondBranchOp>();
|
||||||
|
|
||||||
// Assertions must be converted to runtime errors inside async functions.
|
// Assertions must be converted to runtime errors inside async functions.
|
||||||
runtimeTarget.addDynamicallyLegalOp<AssertOp>([&](AssertOp op) -> bool {
|
runtimeTarget.addDynamicallyLegalOp<cf::AssertOp>(
|
||||||
auto func = op->getParentOfType<FuncOp>();
|
[&](cf::AssertOp op) -> bool {
|
||||||
return outlinedFunctions.find(func) == outlinedFunctions.end();
|
auto func = op->getParentOfType<FuncOp>();
|
||||||
});
|
return outlinedFunctions.find(func) == outlinedFunctions.end();
|
||||||
|
});
|
||||||
|
|
||||||
if (eliminateBlockingAwaitOps)
|
if (eliminateBlockingAwaitOps)
|
||||||
runtimeTarget.addDynamicallyLegalOp<RuntimeAwaitOp>(
|
runtimeTarget.addDynamicallyLegalOp<RuntimeAwaitOp>(
|
||||||
|
|
|
@ -17,7 +17,7 @@ add_mlir_dialect_library(MLIRAsyncTransforms
|
||||||
MLIRIR
|
MLIRIR
|
||||||
MLIRPass
|
MLIRPass
|
||||||
MLIRSCF
|
MLIRSCF
|
||||||
MLIRSCFToStandard
|
MLIRSCFToControlFlow
|
||||||
MLIRStandard
|
MLIRStandard
|
||||||
MLIRTransforms
|
MLIRTransforms
|
||||||
MLIRTransformUtils
|
MLIRTransformUtils
|
||||||
|
|
|
@ -18,12 +18,12 @@
|
||||||
// (using the BufferViewFlowAnalysis class). Consider the following example:
|
// (using the BufferViewFlowAnalysis class). Consider the following example:
|
||||||
//
|
//
|
||||||
// ^bb0(%arg0):
|
// ^bb0(%arg0):
|
||||||
// cond_br %cond, ^bb1, ^bb2
|
// cf.cond_br %cond, ^bb1, ^bb2
|
||||||
// ^bb1:
|
// ^bb1:
|
||||||
// br ^exit(%arg0)
|
// cf.br ^exit(%arg0)
|
||||||
// ^bb2:
|
// ^bb2:
|
||||||
// %new_value = ...
|
// %new_value = ...
|
||||||
// br ^exit(%new_value)
|
// cf.br ^exit(%new_value)
|
||||||
// ^exit(%arg1):
|
// ^exit(%arg1):
|
||||||
// return %arg1;
|
// return %arg1;
|
||||||
//
|
//
|
||||||
|
|
|
@ -6,6 +6,7 @@ add_subdirectory(Async)
|
||||||
add_subdirectory(AMX)
|
add_subdirectory(AMX)
|
||||||
add_subdirectory(Bufferization)
|
add_subdirectory(Bufferization)
|
||||||
add_subdirectory(Complex)
|
add_subdirectory(Complex)
|
||||||
|
add_subdirectory(ControlFlow)
|
||||||
add_subdirectory(DLTI)
|
add_subdirectory(DLTI)
|
||||||
add_subdirectory(EmitC)
|
add_subdirectory(EmitC)
|
||||||
add_subdirectory(GPU)
|
add_subdirectory(GPU)
|
||||||
|
|
|
@ -0,0 +1 @@
|
||||||
|
add_subdirectory(IR)
|
|
@ -0,0 +1,15 @@
|
||||||
|
add_mlir_dialect_library(MLIRControlFlow
|
||||||
|
ControlFlowOps.cpp
|
||||||
|
|
||||||
|
ADDITIONAL_HEADER_DIRS
|
||||||
|
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/ControlFlow/IR
|
||||||
|
|
||||||
|
DEPENDS
|
||||||
|
MLIRControlFlowOpsIncGen
|
||||||
|
|
||||||
|
LINK_LIBS PUBLIC
|
||||||
|
MLIRArithmetic
|
||||||
|
MLIRControlFlowInterfaces
|
||||||
|
MLIRIR
|
||||||
|
MLIRSideEffectInterfaces
|
||||||
|
)
|
|
@ -0,0 +1,891 @@
|
||||||
|
//===- ControlFlowOps.cpp - ControlFlow Operations ------------------------===//
|
||||||
|
//
|
||||||
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
// See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
|
||||||
|
|
||||||
|
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
|
||||||
|
#include "mlir/Dialect/CommonFolders.h"
|
||||||
|
#include "mlir/IR/AffineExpr.h"
|
||||||
|
#include "mlir/IR/AffineMap.h"
|
||||||
|
#include "mlir/IR/BlockAndValueMapping.h"
|
||||||
|
#include "mlir/IR/Builders.h"
|
||||||
|
#include "mlir/IR/BuiltinOps.h"
|
||||||
|
#include "mlir/IR/BuiltinTypes.h"
|
||||||
|
#include "mlir/IR/Matchers.h"
|
||||||
|
#include "mlir/IR/OpImplementation.h"
|
||||||
|
#include "mlir/IR/PatternMatch.h"
|
||||||
|
#include "mlir/IR/TypeUtilities.h"
|
||||||
|
#include "mlir/IR/Value.h"
|
||||||
|
#include "mlir/Support/MathExtras.h"
|
||||||
|
#include "mlir/Transforms/InliningUtils.h"
|
||||||
|
#include "llvm/ADT/APFloat.h"
|
||||||
|
#include "llvm/ADT/STLExtras.h"
|
||||||
|
#include "llvm/ADT/StringSwitch.h"
|
||||||
|
#include "llvm/Support/FormatVariadic.h"
|
||||||
|
#include "llvm/Support/raw_ostream.h"
|
||||||
|
#include <numeric>
|
||||||
|
|
||||||
|
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOpsDialect.cpp.inc"
|
||||||
|
|
||||||
|
using namespace mlir;
|
||||||
|
using namespace mlir::cf;
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// ControlFlowDialect Interfaces
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
namespace {
|
||||||
|
/// This class defines the interface for handling inlining with control flow
|
||||||
|
/// operations.
|
||||||
|
struct ControlFlowInlinerInterface : public DialectInlinerInterface {
|
||||||
|
using DialectInlinerInterface::DialectInlinerInterface;
|
||||||
|
~ControlFlowInlinerInterface() override = default;
|
||||||
|
|
||||||
|
/// All control flow operations can be inlined.
|
||||||
|
bool isLegalToInline(Operation *call, Operation *callable,
|
||||||
|
bool wouldBeCloned) const final {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
bool isLegalToInline(Operation *, Region *, bool,
|
||||||
|
BlockAndValueMapping &) const final {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// ControlFlow terminator operations don't really need any special handing.
|
||||||
|
void handleTerminator(Operation *op, Block *newDest) const final {}
|
||||||
|
};
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// ControlFlowDialect
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
/// Registers the dialect's operations (the list is expanded from the
/// generated ControlFlowOps.cpp.inc) and then its inliner interface.
void ControlFlowDialect::initialize() {
  // The operation list is spliced in as template arguments by the include.
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.cpp.inc"
      >();

  // Make the dialect's operations visible to the inliner.
  addInterfaces<ControlFlowInlinerInterface>();
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// AssertOp
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
/// Canonicalization for the assert op: when the asserted argument matches
/// `m_One()` the op is erased; in every other case the pattern fails and
/// leaves the IR untouched.
LogicalResult AssertOp::canonicalize(AssertOp op, PatternRewriter &rewriter) {
  // Nothing to simplify unless the argument is a matched constant one.
  if (!matchPattern(op.getArg(), m_One()))
    return failure();

  rewriter.eraseOp(op);
  return success();
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// BranchOp
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
/// Given a successor, try to collapse it to a new destination if it only
|
||||||
|
/// contains a passthrough unconditional branch. If the successor is
|
||||||
|
/// collapsable, `successor` and `successorOperands` are updated to reference
|
||||||
|
/// the new destination and values. `argStorage` is used as storage if operands
|
||||||
|
/// to the collapsed successor need to be remapped. It must outlive uses of
|
||||||
|
/// successorOperands.
|
||||||
|
static LogicalResult collapseBranch(Block *&successor,
|
||||||
|
ValueRange &successorOperands,
|
||||||
|
SmallVectorImpl<Value> &argStorage) {
|
||||||
|
// Check that the successor only contains a unconditional branch.
|
||||||
|
if (std::next(successor->begin()) != successor->end())
|
||||||
|
return failure();
|
||||||
|
// Check that the terminator is an unconditional branch.
|
||||||
|
BranchOp successorBranch = dyn_cast<BranchOp>(successor->getTerminator());
|
||||||
|
if (!successorBranch)
|
||||||
|
return failure();
|
||||||
|
// Check that the arguments are only used within the terminator.
|
||||||
|
for (BlockArgument arg : successor->getArguments()) {
|
||||||
|
for (Operation *user : arg.getUsers())
|
||||||
|
if (user != successorBranch)
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
// Don't try to collapse branches to infinite loops.
|
||||||
|
Block *successorDest = successorBranch.getDest();
|
||||||
|
if (successorDest == successor)
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
// Update the operands to the successor. If the branch parent has no
|
||||||
|
// arguments, we can use the branch operands directly.
|
||||||
|
OperandRange operands = successorBranch.getOperands();
|
||||||
|
if (successor->args_empty()) {
|
||||||
|
successor = successorDest;
|
||||||
|
successorOperands = operands;
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Otherwise, we need to remap any argument operands.
|
||||||
|
for (Value operand : operands) {
|
||||||
|
BlockArgument argOperand = operand.dyn_cast<BlockArgument>();
|
||||||
|
if (argOperand && argOperand.getOwner() == successor)
|
||||||
|
argStorage.push_back(successorOperands[argOperand.getArgNumber()]);
|
||||||
|
else
|
||||||
|
argStorage.push_back(operand);
|
||||||
|
}
|
||||||
|
successor = successorDest;
|
||||||
|
successorOperands = argStorage;
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Simplify a branch to a block that has a single predecessor. This effectively
|
||||||
|
/// merges the two blocks.
|
||||||
|
static LogicalResult
|
||||||
|
simplifyBrToBlockWithSinglePred(BranchOp op, PatternRewriter &rewriter) {
|
||||||
|
// Check that the successor block has a single predecessor.
|
||||||
|
Block *succ = op.getDest();
|
||||||
|
Block *opParent = op->getBlock();
|
||||||
|
if (succ == opParent || !llvm::hasSingleElement(succ->getPredecessors()))
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
// Merge the successor into the current block and erase the branch.
|
||||||
|
rewriter.mergeBlocks(succ, opParent, op.getOperands());
|
||||||
|
rewriter.eraseOp(op);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
/// br ^bb1
|
||||||
|
/// ^bb1
|
||||||
|
/// br ^bbN(...)
|
||||||
|
///
|
||||||
|
/// -> br ^bbN(...)
|
||||||
|
///
|
||||||
|
static LogicalResult simplifyPassThroughBr(BranchOp op,
|
||||||
|
PatternRewriter &rewriter) {
|
||||||
|
Block *dest = op.getDest();
|
||||||
|
ValueRange destOperands = op.getOperands();
|
||||||
|
SmallVector<Value, 4> destOperandStorage;
|
||||||
|
|
||||||
|
// Try to collapse the successor if it points somewhere other than this
|
||||||
|
// block.
|
||||||
|
if (dest == op->getBlock() ||
|
||||||
|
failed(collapseBranch(dest, destOperands, destOperandStorage)))
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
// Create a new branch with the collapsed successor.
|
||||||
|
rewriter.replaceOpWithNewOp<BranchOp>(op, dest, destOperands);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Canonicalizes an unconditional branch: first tries to merge the
/// destination into the current block; only if that fails, tries to skip a
/// pass-through destination.
LogicalResult BranchOp::canonicalize(BranchOp op, PatternRewriter &rewriter) {
  if (succeeded(simplifyBrToBlockWithSinglePred(op, rewriter)))
    return success();
  return simplifyPassThroughBr(op, rewriter);
}
|
||||||
|
|
||||||
|
void BranchOp::setDest(Block *block) { return setSuccessor(block); }
|
||||||
|
|
||||||
|
void BranchOp::eraseOperand(unsigned index) { (*this)->eraseOperand(index); }
|
||||||
|
|
||||||
|
Optional<MutableOperandRange>
|
||||||
|
BranchOp::getMutableSuccessorOperands(unsigned index) {
|
||||||
|
assert(index == 0 && "invalid successor index");
|
||||||
|
return getDestOperandsMutable();
|
||||||
|
}
|
||||||
|
|
||||||
|
Block *BranchOp::getSuccessorForOperands(ArrayRef<Attribute>) {
|
||||||
|
return getDest();
|
||||||
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// CondBranchOp
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
/// cf.cond_br true, ^bb1, ^bb2
|
||||||
|
/// -> br ^bb1
|
||||||
|
/// cf.cond_br false, ^bb1, ^bb2
|
||||||
|
/// -> br ^bb2
|
||||||
|
///
|
||||||
|
struct SimplifyConstCondBranchPred : public OpRewritePattern<CondBranchOp> {
|
||||||
|
using OpRewritePattern<CondBranchOp>::OpRewritePattern;
|
||||||
|
|
||||||
|
LogicalResult matchAndRewrite(CondBranchOp condbr,
|
||||||
|
PatternRewriter &rewriter) const override {
|
||||||
|
if (matchPattern(condbr.getCondition(), m_NonZero())) {
|
||||||
|
// True branch taken.
|
||||||
|
rewriter.replaceOpWithNewOp<BranchOp>(condbr, condbr.getTrueDest(),
|
||||||
|
condbr.getTrueOperands());
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
if (matchPattern(condbr.getCondition(), m_Zero())) {
|
||||||
|
// False branch taken.
|
||||||
|
rewriter.replaceOpWithNewOp<BranchOp>(condbr, condbr.getFalseDest(),
|
||||||
|
condbr.getFalseOperands());
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
/// cf.cond_br %cond, ^bb1, ^bb2
|
||||||
|
/// ^bb1
|
||||||
|
/// br ^bbN(...)
|
||||||
|
/// ^bb2
|
||||||
|
/// br ^bbK(...)
|
||||||
|
///
|
||||||
|
/// -> cf.cond_br %cond, ^bbN(...), ^bbK(...)
|
||||||
|
///
|
||||||
|
struct SimplifyPassThroughCondBranch : public OpRewritePattern<CondBranchOp> {
|
||||||
|
using OpRewritePattern<CondBranchOp>::OpRewritePattern;
|
||||||
|
|
||||||
|
LogicalResult matchAndRewrite(CondBranchOp condbr,
|
||||||
|
PatternRewriter &rewriter) const override {
|
||||||
|
Block *trueDest = condbr.getTrueDest(), *falseDest = condbr.getFalseDest();
|
||||||
|
ValueRange trueDestOperands = condbr.getTrueOperands();
|
||||||
|
ValueRange falseDestOperands = condbr.getFalseOperands();
|
||||||
|
SmallVector<Value, 4> trueDestOperandStorage, falseDestOperandStorage;
|
||||||
|
|
||||||
|
// Try to collapse one of the current successors.
|
||||||
|
LogicalResult collapsedTrue =
|
||||||
|
collapseBranch(trueDest, trueDestOperands, trueDestOperandStorage);
|
||||||
|
LogicalResult collapsedFalse =
|
||||||
|
collapseBranch(falseDest, falseDestOperands, falseDestOperandStorage);
|
||||||
|
if (failed(collapsedTrue) && failed(collapsedFalse))
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
// Create a new branch with the collapsed successors.
|
||||||
|
rewriter.replaceOpWithNewOp<CondBranchOp>(condbr, condbr.getCondition(),
|
||||||
|
trueDest, trueDestOperands,
|
||||||
|
falseDest, falseDestOperands);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
/// cf.cond_br %cond, ^bb1(A, ..., N), ^bb1(A, ..., N)
|
||||||
|
/// -> br ^bb1(A, ..., N)
|
||||||
|
///
|
||||||
|
/// cf.cond_br %cond, ^bb1(A), ^bb1(B)
|
||||||
|
/// -> %select = arith.select %cond, A, B
|
||||||
|
/// br ^bb1(%select)
|
||||||
|
///
|
||||||
|
struct SimplifyCondBranchIdenticalSuccessors
|
||||||
|
: public OpRewritePattern<CondBranchOp> {
|
||||||
|
using OpRewritePattern<CondBranchOp>::OpRewritePattern;
|
||||||
|
|
||||||
|
LogicalResult matchAndRewrite(CondBranchOp condbr,
|
||||||
|
PatternRewriter &rewriter) const override {
|
||||||
|
// Check that the true and false destinations are the same and have the same
|
||||||
|
// operands.
|
||||||
|
Block *trueDest = condbr.getTrueDest();
|
||||||
|
if (trueDest != condbr.getFalseDest())
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
// If all of the operands match, no selects need to be generated.
|
||||||
|
OperandRange trueOperands = condbr.getTrueOperands();
|
||||||
|
OperandRange falseOperands = condbr.getFalseOperands();
|
||||||
|
if (trueOperands == falseOperands) {
|
||||||
|
rewriter.replaceOpWithNewOp<BranchOp>(condbr, trueDest, trueOperands);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Otherwise, if the current block is the only predecessor insert selects
|
||||||
|
// for any mismatched branch operands.
|
||||||
|
if (trueDest->getUniquePredecessor() != condbr->getBlock())
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
// Generate a select for any operands that differ between the two.
|
||||||
|
SmallVector<Value, 8> mergedOperands;
|
||||||
|
mergedOperands.reserve(trueOperands.size());
|
||||||
|
Value condition = condbr.getCondition();
|
||||||
|
for (auto it : llvm::zip(trueOperands, falseOperands)) {
|
||||||
|
if (std::get<0>(it) == std::get<1>(it))
|
||||||
|
mergedOperands.push_back(std::get<0>(it));
|
||||||
|
else
|
||||||
|
mergedOperands.push_back(rewriter.create<arith::SelectOp>(
|
||||||
|
condbr.getLoc(), condition, std::get<0>(it), std::get<1>(it)));
|
||||||
|
}
|
||||||
|
|
||||||
|
rewriter.replaceOpWithNewOp<BranchOp>(condbr, trueDest, mergedOperands);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
/// ...
|
||||||
|
/// cf.cond_br %cond, ^bb1(...), ^bb2(...)
|
||||||
|
/// ...
|
||||||
|
/// ^bb1: // has single predecessor
|
||||||
|
/// ...
|
||||||
|
/// cf.cond_br %cond, ^bb3(...), ^bb4(...)
|
||||||
|
///
|
||||||
|
/// ->
|
||||||
|
///
|
||||||
|
/// ...
|
||||||
|
/// cf.cond_br %cond, ^bb1(...), ^bb2(...)
|
||||||
|
/// ...
|
||||||
|
/// ^bb1: // has single predecessor
|
||||||
|
/// ...
|
||||||
|
/// br ^bb3(...)
|
||||||
|
///
|
||||||
|
struct SimplifyCondBranchFromCondBranchOnSameCondition
|
||||||
|
: public OpRewritePattern<CondBranchOp> {
|
||||||
|
using OpRewritePattern<CondBranchOp>::OpRewritePattern;
|
||||||
|
|
||||||
|
LogicalResult matchAndRewrite(CondBranchOp condbr,
|
||||||
|
PatternRewriter &rewriter) const override {
|
||||||
|
// Check that we have a single distinct predecessor.
|
||||||
|
Block *currentBlock = condbr->getBlock();
|
||||||
|
Block *predecessor = currentBlock->getSinglePredecessor();
|
||||||
|
if (!predecessor)
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
// Check that the predecessor terminates with a conditional branch to this
|
||||||
|
// block and that it branches on the same condition.
|
||||||
|
auto predBranch = dyn_cast<CondBranchOp>(predecessor->getTerminator());
|
||||||
|
if (!predBranch || condbr.getCondition() != predBranch.getCondition())
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
// Fold this branch to an unconditional branch.
|
||||||
|
if (currentBlock == predBranch.getTrueDest())
|
||||||
|
rewriter.replaceOpWithNewOp<BranchOp>(condbr, condbr.getTrueDest(),
|
||||||
|
condbr.getTrueDestOperands());
|
||||||
|
else
|
||||||
|
rewriter.replaceOpWithNewOp<BranchOp>(condbr, condbr.getFalseDest(),
|
||||||
|
condbr.getFalseDestOperands());
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
/// cf.cond_br %arg0, ^trueB, ^falseB
|
||||||
|
///
|
||||||
|
/// ^trueB:
|
||||||
|
/// "test.consumer1"(%arg0) : (i1) -> ()
|
||||||
|
/// ...
|
||||||
|
///
|
||||||
|
/// ^falseB:
|
||||||
|
/// "test.consumer2"(%arg0) : (i1) -> ()
|
||||||
|
/// ...
|
||||||
|
///
|
||||||
|
/// ->
|
||||||
|
///
|
||||||
|
/// cf.cond_br %arg0, ^trueB, ^falseB
|
||||||
|
/// ^trueB:
|
||||||
|
/// "test.consumer1"(%true) : (i1) -> ()
|
||||||
|
/// ...
|
||||||
|
///
|
||||||
|
/// ^falseB:
|
||||||
|
/// "test.consumer2"(%false) : (i1) -> ()
|
||||||
|
/// ...
|
||||||
|
struct CondBranchTruthPropagation : public OpRewritePattern<CondBranchOp> {
|
||||||
|
using OpRewritePattern<CondBranchOp>::OpRewritePattern;
|
||||||
|
|
||||||
|
LogicalResult matchAndRewrite(CondBranchOp condbr,
|
||||||
|
PatternRewriter &rewriter) const override {
|
||||||
|
// Check that we have a single distinct predecessor.
|
||||||
|
bool replaced = false;
|
||||||
|
Type ty = rewriter.getI1Type();
|
||||||
|
|
||||||
|
// These variables serve to prevent creating duplicate constants
|
||||||
|
// and hold constant true or false values.
|
||||||
|
Value constantTrue = nullptr;
|
||||||
|
Value constantFalse = nullptr;
|
||||||
|
|
||||||
|
// TODO These checks can be expanded to encompas any use with only
|
||||||
|
// either the true of false edge as a predecessor. For now, we fall
|
||||||
|
// back to checking the single predecessor is given by the true/fasle
|
||||||
|
// destination, thereby ensuring that only that edge can reach the
|
||||||
|
// op.
|
||||||
|
if (condbr.getTrueDest()->getSinglePredecessor()) {
|
||||||
|
for (OpOperand &use :
|
||||||
|
llvm::make_early_inc_range(condbr.getCondition().getUses())) {
|
||||||
|
if (use.getOwner()->getBlock() == condbr.getTrueDest()) {
|
||||||
|
replaced = true;
|
||||||
|
|
||||||
|
if (!constantTrue)
|
||||||
|
constantTrue = rewriter.create<arith::ConstantOp>(
|
||||||
|
condbr.getLoc(), ty, rewriter.getBoolAttr(true));
|
||||||
|
|
||||||
|
rewriter.updateRootInPlace(use.getOwner(),
|
||||||
|
[&] { use.set(constantTrue); });
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (condbr.getFalseDest()->getSinglePredecessor()) {
|
||||||
|
for (OpOperand &use :
|
||||||
|
llvm::make_early_inc_range(condbr.getCondition().getUses())) {
|
||||||
|
if (use.getOwner()->getBlock() == condbr.getFalseDest()) {
|
||||||
|
replaced = true;
|
||||||
|
|
||||||
|
if (!constantFalse)
|
||||||
|
constantFalse = rewriter.create<arith::ConstantOp>(
|
||||||
|
condbr.getLoc(), ty, rewriter.getBoolAttr(false));
|
||||||
|
|
||||||
|
rewriter.updateRootInPlace(use.getOwner(),
|
||||||
|
[&] { use.set(constantFalse); });
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return success(replaced);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
/// Adds all conditional-branch canonicalization patterns defined in this
/// file to `results`, in the order listed below.
void CondBranchOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                               MLIRContext *context) {
  results.add<SimplifyConstCondBranchPred, SimplifyPassThroughCondBranch>(
      context);
  results.add<SimplifyCondBranchIdenticalSuccessors,
              SimplifyCondBranchFromCondBranchOnSameCondition>(context);
  results.add<CondBranchTruthPropagation>(context);
}
|
||||||
|
|
||||||
|
/// Returns the mutable operand range forwarded to successor `index`: the
/// true-destination operands for `trueIndex`, the false-destination operands
/// for any other valid index.
Optional<MutableOperandRange>
CondBranchOp::getMutableSuccessorOperands(unsigned index) {
  assert(index < getNumSuccessors() && "invalid successor index");
  if (index == trueIndex)
    return getTrueDestOperandsMutable();
  return getFalseDestOperandsMutable();
}
|
||||||
|
|
||||||
|
/// Resolves the taken successor from constant operand values.
///
/// `operands.front()` is read as the branch condition. When it is an
/// `IntegerAttr`, the true destination is returned if its value is one and
/// the false destination otherwise. When it is null or not an `IntegerAttr`,
/// the successor cannot be determined and `nullptr` is returned.
Block *CondBranchOp::getSuccessorForOperands(ArrayRef<Attribute> operands) {
  auto condAttr = operands.front().dyn_cast_or_null<IntegerAttr>();
  if (!condAttr)
    return nullptr;

  // `APInt::isOne()` is the current spelling of the soft-deprecated
  // `isOneValue()`; both test whether the value equals one.
  return condAttr.getValue().isOne() ? getTrueDest() : getFalseDest();
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// SwitchOp
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
/// Convenience builder taking the case values as a `DenseIntElementsAttr`.
/// It only reorders its arguments and forwards them to another `build`
/// overload (not shown here; presumably the generated one — confirm).
void SwitchOp::build(OpBuilder &builder, OperationState &result, Value value,
                     Block *defaultDestination, ValueRange defaultOperands,
                     DenseIntElementsAttr caseValues,
                     BlockRange caseDestinations,
                     ArrayRef<ValueRange> caseOperands) {
  // Target argument order: flag, operand groups, case values, destinations.
  build(builder, result, /*flag=*/value, defaultOperands, caseOperands,
        caseValues, defaultDestination, caseDestinations);
}
|
||||||
|
|
||||||
|
/// Convenience builder taking the case values as a list of `APInt`s.
///
/// A non-empty list is packed into a `DenseIntElementsAttr` whose type is a
/// 1-D vector with one element per case and `value`'s type as element type.
/// An empty list leaves the attribute null. Everything is then forwarded to
/// the attribute-based overload.
void SwitchOp::build(OpBuilder &builder, OperationState &result, Value value,
                     Block *defaultDestination, ValueRange defaultOperands,
                     ArrayRef<APInt> caseValues, BlockRange caseDestinations,
                     ArrayRef<ValueRange> caseOperands) {
  DenseIntElementsAttr packedCaseValues;
  const bool hasCases = !caseValues.empty();
  if (hasCases) {
    const auto numCases = static_cast<int64_t>(caseValues.size());
    ShapedType packedType = VectorType::get(numCases, value.getType());
    packedCaseValues = DenseIntElementsAttr::get(packedType, caseValues);
  }

  build(builder, result, value, defaultDestination, defaultOperands,
        packedCaseValues, caseDestinations, caseOperands);
}
|
||||||
|
|
||||||
|
/// <cases> ::= `default` `:` bb-id (`(` ssa-use-and-type-list `)`)?
|
||||||
|
/// ( `,` integer `:` bb-id (`(` ssa-use-and-type-list `)`)? )*
|
||||||
|
static ParseResult parseSwitchOpCases(
|
||||||
|
OpAsmParser &parser, Type &flagType, Block *&defaultDestination,
|
||||||
|
SmallVectorImpl<OpAsmParser::OperandType> &defaultOperands,
|
||||||
|
SmallVectorImpl<Type> &defaultOperandTypes,
|
||||||
|
DenseIntElementsAttr &caseValues,
|
||||||
|
SmallVectorImpl<Block *> &caseDestinations,
|
||||||
|
SmallVectorImpl<SmallVector<OpAsmParser::OperandType>> &caseOperands,
|
||||||
|
SmallVectorImpl<SmallVector<Type>> &caseOperandTypes) {
|
||||||
|
if (parser.parseKeyword("default") || parser.parseColon() ||
|
||||||
|
parser.parseSuccessor(defaultDestination))
|
||||||
|
return failure();
|
||||||
|
if (succeeded(parser.parseOptionalLParen())) {
|
||||||
|
if (parser.parseRegionArgumentList(defaultOperands) ||
|
||||||
|
parser.parseColonTypeList(defaultOperandTypes) || parser.parseRParen())
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
|
||||||
|
SmallVector<APInt> values;
|
||||||
|
unsigned bitWidth = flagType.getIntOrFloatBitWidth();
|
||||||
|
while (succeeded(parser.parseOptionalComma())) {
|
||||||
|
int64_t value = 0;
|
||||||
|
if (failed(parser.parseInteger(value)))
|
||||||
|
return failure();
|
||||||
|
values.push_back(APInt(bitWidth, value));
|
||||||
|
|
||||||
|
Block *destination;
|
||||||
|
SmallVector<OpAsmParser::OperandType> operands;
|
||||||
|
SmallVector<Type> operandTypes;
|
||||||
|
if (failed(parser.parseColon()) ||
|
||||||
|
failed(parser.parseSuccessor(destination)))
|
||||||
|
return failure();
|
||||||
|
if (succeeded(parser.parseOptionalLParen())) {
|
||||||
|
if (failed(parser.parseRegionArgumentList(operands)) ||
|
||||||
|
failed(parser.parseColonTypeList(operandTypes)) ||
|
||||||
|
failed(parser.parseRParen()))
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
caseDestinations.push_back(destination);
|
||||||
|
caseOperands.emplace_back(operands);
|
||||||
|
caseOperandTypes.emplace_back(operandTypes);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!values.empty()) {
|
||||||
|
ShapedType caseValueType =
|
||||||
|
VectorType::get(static_cast<int64_t>(values.size()), flagType);
|
||||||
|
caseValues = DenseIntElementsAttr::get(caseValueType, values);
|
||||||
|
}
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Prints the case list of a switch: the default successor with its operands,
/// then, if a case value attribute is present, one `value: successor` entry
/// per case, each preceded by a comma and a newline, and a final newline.
/// Case values are printed through `APInt::getLimitedValue()`.
static void printSwitchOpCases(
    OpAsmPrinter &p, SwitchOp op, Type flagType, Block *defaultDestination,
    OperandRange defaultOperands, TypeRange defaultOperandTypes,
    DenseIntElementsAttr caseValues, SuccessorRange caseDestinations,
    OperandRangeRange caseOperands, const TypeRangeRange &caseOperandTypes) {
  p << " default: ";
  p.printSuccessorAndUseList(defaultDestination, defaultOperands);

  // Nothing more to print when the switch has no case value attribute.
  if (!caseValues)
    return;

  size_t caseIdx = 0;
  for (APInt caseValue : caseValues.getValues<APInt>()) {
    p << ',';
    p.printNewline();
    p << " " << caseValue.getLimitedValue() << ": ";
    p.printSuccessorAndUseList(caseDestinations[caseIdx],
                               caseOperands[caseIdx]);
    ++caseIdx;
  }
  p.printNewline();
}
|
||||||
|
|
||||||
|
/// Verifies the consistency of the case list:
///  - a switch with neither case values nor case destinations is valid;
///  - the element type of the case values must equal the type of the flag;
///  - there must be exactly one case value per case destination.
/// Emits an op error and returns failure on the first violation.
LogicalResult SwitchOp::verify() {
  auto caseValues = getCaseValues();
  auto caseDestinations = getCaseDestinations();

  if (!caseValues) {
    if (caseDestinations.empty())
      return success();
    // Case destinations without any case values: report the count mismatch
    // here rather than falling through and dereferencing the empty optional.
    return emitOpError() << "number of case values (0) should match number of "
                            "case destinations ("
                         << caseDestinations.size() << ")";
  }

  Type flagType = getFlag().getType();
  Type caseValueType = caseValues->getType().getElementType();
  if (caseValueType != flagType)
    return emitOpError() << "'flag' type (" << flagType
                         << ") should match case value type (" << caseValueType
                         << ")";

  if (caseValues->size() != static_cast<int64_t>(caseDestinations.size()))
    return emitOpError() << "number of case values (" << caseValues->size()
                         << ") should match number of "
                            "case destinations ("
                         << caseDestinations.size() << ")";
  return success();
}
|
||||||
|
|
||||||
|
/// Returns the mutable operand range forwarded to the successor at `index`.
/// Successor 0 maps to the default operands; successor `i > 0` maps to the
/// operands of case `i - 1`. Asserts that `index` is a valid successor index.
Optional<MutableOperandRange>
SwitchOp::getMutableSuccessorOperands(unsigned index) {
  assert(index < getNumSuccessors() && "invalid successor index");
  if (index == 0)
    return getDefaultOperandsMutable();
  return getCaseOperandsMutable(index - 1);
}
|
||||||
|
|
||||||
|
/// Picks the successor that is taken when operand values are known constants.
/// Without a case value attribute the default destination is returned
/// unconditionally. Otherwise the first entry of `operands` is inspected: if
/// it is an integer attribute, the destination of the first case with an equal
/// value is returned, or the default destination when no case matches.
/// Returns nullptr when that entry is not an integer attribute.
Block *SwitchOp::getSuccessorForOperands(ArrayRef<Attribute> operands) {
  Optional<DenseIntElementsAttr> caseValues = getCaseValues();
  if (!caseValues)
    return getDefaultDestination();

  SuccessorRange caseDests = getCaseDestinations();
  auto flagAttr = operands.front().dyn_cast_or_null<IntegerAttr>();
  if (!flagAttr)
    return nullptr;

  size_t caseIdx = 0;
  for (APInt candidate : caseValues->getValues<APInt>()) {
    if (candidate == flagAttr.getValue())
      return caseDests[caseIdx];
    ++caseIdx;
  }
  return getDefaultDestination();
}
|
||||||
|
|
||||||
|
/// switch %flag : i32, [
|
||||||
|
/// default: ^bb1
|
||||||
|
/// ]
|
||||||
|
/// -> br ^bb1
|
||||||
|
static LogicalResult simplifySwitchWithOnlyDefault(SwitchOp op,
|
||||||
|
PatternRewriter &rewriter) {
|
||||||
|
if (!op.getCaseDestinations().empty())
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
rewriter.replaceOpWithNewOp<BranchOp>(op, op.getDefaultDestination(),
|
||||||
|
op.getDefaultOperands());
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
/// switch %flag : i32, [
|
||||||
|
/// default: ^bb1,
|
||||||
|
/// 42: ^bb1,
|
||||||
|
/// 43: ^bb2
|
||||||
|
/// ]
|
||||||
|
/// ->
|
||||||
|
/// switch %flag : i32, [
|
||||||
|
/// default: ^bb1,
|
||||||
|
/// 43: ^bb2
|
||||||
|
/// ]
|
||||||
|
static LogicalResult
|
||||||
|
dropSwitchCasesThatMatchDefault(SwitchOp op, PatternRewriter &rewriter) {
|
||||||
|
SmallVector<Block *> newCaseDestinations;
|
||||||
|
SmallVector<ValueRange> newCaseOperands;
|
||||||
|
SmallVector<APInt> newCaseValues;
|
||||||
|
bool requiresChange = false;
|
||||||
|
auto caseValues = op.getCaseValues();
|
||||||
|
auto caseDests = op.getCaseDestinations();
|
||||||
|
|
||||||
|
for (const auto &it : llvm::enumerate(caseValues->getValues<APInt>())) {
|
||||||
|
if (caseDests[it.index()] == op.getDefaultDestination() &&
|
||||||
|
op.getCaseOperands(it.index()) == op.getDefaultOperands()) {
|
||||||
|
requiresChange = true;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
newCaseDestinations.push_back(caseDests[it.index()]);
|
||||||
|
newCaseOperands.push_back(op.getCaseOperands(it.index()));
|
||||||
|
newCaseValues.push_back(it.value());
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!requiresChange)
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
rewriter.replaceOpWithNewOp<SwitchOp>(
|
||||||
|
op, op.getFlag(), op.getDefaultDestination(), op.getDefaultOperands(),
|
||||||
|
newCaseValues, newCaseDestinations, newCaseOperands);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Helper for folding a switch with a constant value.
|
||||||
|
/// switch %c_42 : i32, [
|
||||||
|
/// default: ^bb1 ,
|
||||||
|
/// 42: ^bb2,
|
||||||
|
/// 43: ^bb3
|
||||||
|
/// ]
|
||||||
|
/// -> br ^bb2
|
||||||
|
static void foldSwitch(SwitchOp op, PatternRewriter &rewriter,
|
||||||
|
const APInt &caseValue) {
|
||||||
|
auto caseValues = op.getCaseValues();
|
||||||
|
for (const auto &it : llvm::enumerate(caseValues->getValues<APInt>())) {
|
||||||
|
if (it.value() == caseValue) {
|
||||||
|
rewriter.replaceOpWithNewOp<BranchOp>(
|
||||||
|
op, op.getCaseDestinations()[it.index()],
|
||||||
|
op.getCaseOperands(it.index()));
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
rewriter.replaceOpWithNewOp<BranchOp>(op, op.getDefaultDestination(),
|
||||||
|
op.getDefaultOperands());
|
||||||
|
}
|
||||||
|
|
||||||
|
/// switch %c_42 : i32, [
|
||||||
|
/// default: ^bb1,
|
||||||
|
/// 42: ^bb2,
|
||||||
|
/// 43: ^bb3
|
||||||
|
/// ]
|
||||||
|
/// -> br ^bb2
|
||||||
|
static LogicalResult simplifyConstSwitchValue(SwitchOp op,
|
||||||
|
PatternRewriter &rewriter) {
|
||||||
|
APInt caseValue;
|
||||||
|
if (!matchPattern(op.getFlag(), m_ConstantInt(&caseValue)))
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
foldSwitch(op, rewriter, caseValue);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
/// switch %c_42 : i32, [
|
||||||
|
/// default: ^bb1,
|
||||||
|
/// 42: ^bb2,
|
||||||
|
/// ]
|
||||||
|
/// ^bb2:
|
||||||
|
/// br ^bb3
|
||||||
|
/// ->
|
||||||
|
/// switch %c_42 : i32, [
|
||||||
|
/// default: ^bb1,
|
||||||
|
/// 42: ^bb3,
|
||||||
|
/// ]
|
||||||
|
static LogicalResult simplifyPassThroughSwitch(SwitchOp op,
|
||||||
|
PatternRewriter &rewriter) {
|
||||||
|
SmallVector<Block *> newCaseDests;
|
||||||
|
SmallVector<ValueRange> newCaseOperands;
|
||||||
|
SmallVector<SmallVector<Value>> argStorage;
|
||||||
|
auto caseValues = op.getCaseValues();
|
||||||
|
auto caseDests = op.getCaseDestinations();
|
||||||
|
bool requiresChange = false;
|
||||||
|
for (int64_t i = 0, size = caseValues->size(); i < size; ++i) {
|
||||||
|
Block *caseDest = caseDests[i];
|
||||||
|
ValueRange caseOperands = op.getCaseOperands(i);
|
||||||
|
argStorage.emplace_back();
|
||||||
|
if (succeeded(collapseBranch(caseDest, caseOperands, argStorage.back())))
|
||||||
|
requiresChange = true;
|
||||||
|
|
||||||
|
newCaseDests.push_back(caseDest);
|
||||||
|
newCaseOperands.push_back(caseOperands);
|
||||||
|
}
|
||||||
|
|
||||||
|
Block *defaultDest = op.getDefaultDestination();
|
||||||
|
ValueRange defaultOperands = op.getDefaultOperands();
|
||||||
|
argStorage.emplace_back();
|
||||||
|
|
||||||
|
if (succeeded(
|
||||||
|
collapseBranch(defaultDest, defaultOperands, argStorage.back())))
|
||||||
|
requiresChange = true;
|
||||||
|
|
||||||
|
if (!requiresChange)
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
rewriter.replaceOpWithNewOp<SwitchOp>(op, op.getFlag(), defaultDest,
|
||||||
|
defaultOperands, caseValues.getValue(),
|
||||||
|
newCaseDests, newCaseOperands);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
/// switch %flag : i32, [
|
||||||
|
/// default: ^bb1,
|
||||||
|
/// 42: ^bb2,
|
||||||
|
/// ]
|
||||||
|
/// ^bb2:
|
||||||
|
/// switch %flag : i32, [
|
||||||
|
/// default: ^bb3,
|
||||||
|
/// 42: ^bb4
|
||||||
|
/// ]
|
||||||
|
/// ->
|
||||||
|
/// switch %flag : i32, [
|
||||||
|
/// default: ^bb1,
|
||||||
|
/// 42: ^bb2,
|
||||||
|
/// ]
|
||||||
|
/// ^bb2:
|
||||||
|
/// br ^bb4
|
||||||
|
///
|
||||||
|
/// and
|
||||||
|
///
|
||||||
|
/// switch %flag : i32, [
|
||||||
|
/// default: ^bb1,
|
||||||
|
/// 42: ^bb2,
|
||||||
|
/// ]
|
||||||
|
/// ^bb2:
|
||||||
|
/// switch %flag : i32, [
|
||||||
|
/// default: ^bb3,
|
||||||
|
/// 43: ^bb4
|
||||||
|
/// ]
|
||||||
|
/// ->
|
||||||
|
/// switch %flag : i32, [
|
||||||
|
/// default: ^bb1,
|
||||||
|
/// 42: ^bb2,
|
||||||
|
/// ]
|
||||||
|
/// ^bb2:
|
||||||
|
/// br ^bb3
|
||||||
|
static LogicalResult
|
||||||
|
simplifySwitchFromSwitchOnSameCondition(SwitchOp op,
|
||||||
|
PatternRewriter &rewriter) {
|
||||||
|
// Check that we have a single distinct predecessor.
|
||||||
|
Block *currentBlock = op->getBlock();
|
||||||
|
Block *predecessor = currentBlock->getSinglePredecessor();
|
||||||
|
if (!predecessor)
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
// Check that the predecessor terminates with a switch branch to this block
|
||||||
|
// and that it branches on the same condition and that this branch isn't the
|
||||||
|
// default destination.
|
||||||
|
auto predSwitch = dyn_cast<SwitchOp>(predecessor->getTerminator());
|
||||||
|
if (!predSwitch || op.getFlag() != predSwitch.getFlag() ||
|
||||||
|
predSwitch.getDefaultDestination() == currentBlock)
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
// Fold this switch to an unconditional branch.
|
||||||
|
SuccessorRange predDests = predSwitch.getCaseDestinations();
|
||||||
|
auto it = llvm::find(predDests, currentBlock);
|
||||||
|
if (it != predDests.end()) {
|
||||||
|
Optional<DenseIntElementsAttr> predCaseValues = predSwitch.getCaseValues();
|
||||||
|
foldSwitch(op, rewriter,
|
||||||
|
predCaseValues->getValues<APInt>()[it - predDests.begin()]);
|
||||||
|
} else {
|
||||||
|
rewriter.replaceOpWithNewOp<BranchOp>(op, op.getDefaultDestination(),
|
||||||
|
op.getDefaultOperands());
|
||||||
|
}
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
/// switch %flag : i32, [
|
||||||
|
/// default: ^bb1,
|
||||||
|
/// 42: ^bb2
|
||||||
|
/// ]
|
||||||
|
/// ^bb1:
|
||||||
|
/// switch %flag : i32, [
|
||||||
|
/// default: ^bb3,
|
||||||
|
/// 42: ^bb4,
|
||||||
|
/// 43: ^bb5
|
||||||
|
/// ]
|
||||||
|
/// ->
|
||||||
|
/// switch %flag : i32, [
|
||||||
|
/// default: ^bb1,
|
||||||
|
/// 42: ^bb2,
|
||||||
|
/// ]
|
||||||
|
/// ^bb1:
|
||||||
|
/// switch %flag : i32, [
|
||||||
|
/// default: ^bb3,
|
||||||
|
/// 43: ^bb5
|
||||||
|
/// ]
|
||||||
|
static LogicalResult
|
||||||
|
simplifySwitchFromDefaultSwitchOnSameCondition(SwitchOp op,
|
||||||
|
PatternRewriter &rewriter) {
|
||||||
|
// Check that we have a single distinct predecessor.
|
||||||
|
Block *currentBlock = op->getBlock();
|
||||||
|
Block *predecessor = currentBlock->getSinglePredecessor();
|
||||||
|
if (!predecessor)
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
// Check that the predecessor terminates with a switch branch to this block
|
||||||
|
// and that it branches on the same condition and that this branch is the
|
||||||
|
// default destination.
|
||||||
|
auto predSwitch = dyn_cast<SwitchOp>(predecessor->getTerminator());
|
||||||
|
if (!predSwitch || op.getFlag() != predSwitch.getFlag() ||
|
||||||
|
predSwitch.getDefaultDestination() != currentBlock)
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
// Delete case values that are not possible here.
|
||||||
|
DenseSet<APInt> caseValuesToRemove;
|
||||||
|
auto predDests = predSwitch.getCaseDestinations();
|
||||||
|
auto predCaseValues = predSwitch.getCaseValues();
|
||||||
|
for (int64_t i = 0, size = predCaseValues->size(); i < size; ++i)
|
||||||
|
if (currentBlock != predDests[i])
|
||||||
|
caseValuesToRemove.insert(predCaseValues->getValues<APInt>()[i]);
|
||||||
|
|
||||||
|
SmallVector<Block *> newCaseDestinations;
|
||||||
|
SmallVector<ValueRange> newCaseOperands;
|
||||||
|
SmallVector<APInt> newCaseValues;
|
||||||
|
bool requiresChange = false;
|
||||||
|
|
||||||
|
auto caseValues = op.getCaseValues();
|
||||||
|
auto caseDests = op.getCaseDestinations();
|
||||||
|
for (const auto &it : llvm::enumerate(caseValues->getValues<APInt>())) {
|
||||||
|
if (caseValuesToRemove.contains(it.value())) {
|
||||||
|
requiresChange = true;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
newCaseDestinations.push_back(caseDests[it.index()]);
|
||||||
|
newCaseOperands.push_back(op.getCaseOperands(it.index()));
|
||||||
|
newCaseValues.push_back(it.value());
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!requiresChange)
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
rewriter.replaceOpWithNewOp<SwitchOp>(
|
||||||
|
op, op.getFlag(), op.getDefaultDestination(), op.getDefaultOperands(),
|
||||||
|
newCaseValues, newCaseDestinations, newCaseOperands);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Registers the switch canonicalization rewrites with `results`. The
/// functions are added in the order listed below; `context` is not used.
void SwitchOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                           MLIRContext *context) {
  results.add(&simplifySwitchWithOnlyDefault);
  results.add(&dropSwitchCasesThatMatchDefault);
  results.add(&simplifyConstSwitchValue);
  results.add(&simplifyPassThroughSwitch);
  results.add(&simplifySwitchFromSwitchOnSameCondition);
  results.add(&simplifySwitchFromDefaultSwitchOnSameCondition);
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// TableGen'd op method definitions
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#define GET_OP_CLASSES
|
||||||
|
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.cpp.inc"
|
|
@ -12,10 +12,10 @@
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
|
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
|
||||||
|
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
|
||||||
#include "mlir/Dialect/GPU/GPUDialect.h"
|
#include "mlir/Dialect/GPU/GPUDialect.h"
|
||||||
#include "mlir/Dialect/GPU/Passes.h"
|
#include "mlir/Dialect/GPU/Passes.h"
|
||||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
|
||||||
#include "mlir/IR/BlockAndValueMapping.h"
|
#include "mlir/IR/BlockAndValueMapping.h"
|
||||||
#include "mlir/IR/Builders.h"
|
#include "mlir/IR/Builders.h"
|
||||||
#include "mlir/IR/PatternMatch.h"
|
#include "mlir/IR/PatternMatch.h"
|
||||||
|
@ -44,14 +44,14 @@ struct GpuAllReduceRewriter {
|
||||||
/// workgroup memory.
|
/// workgroup memory.
|
||||||
///
|
///
|
||||||
/// %subgroup_reduce = `createSubgroupReduce(%operand)`
|
/// %subgroup_reduce = `createSubgroupReduce(%operand)`
|
||||||
/// cond_br %is_first_lane, ^then1, ^continue1
|
/// cf.cond_br %is_first_lane, ^then1, ^continue1
|
||||||
/// ^then1:
|
/// ^then1:
|
||||||
/// store %subgroup_reduce, %workgroup_buffer[%subgroup_id]
|
/// store %subgroup_reduce, %workgroup_buffer[%subgroup_id]
|
||||||
/// br ^continue1
|
/// cf.br ^continue1
|
||||||
/// ^continue1:
|
/// ^continue1:
|
||||||
/// gpu.barrier
|
/// gpu.barrier
|
||||||
/// %is_valid_subgroup = arith.cmpi "slt" %invocation_idx, %num_subgroups
|
/// %is_valid_subgroup = arith.cmpi "slt" %invocation_idx, %num_subgroups
|
||||||
/// cond_br %is_valid_subgroup, ^then2, ^continue2
|
/// cf.cond_br %is_valid_subgroup, ^then2, ^continue2
|
||||||
/// ^then2:
|
/// ^then2:
|
||||||
/// %partial_reduce = load %workgroup_buffer[%invocation_idx]
|
/// %partial_reduce = load %workgroup_buffer[%invocation_idx]
|
||||||
/// %all_reduce = `createSubgroupReduce(%partial_reduce)`
|
/// %all_reduce = `createSubgroupReduce(%partial_reduce)`
|
||||||
|
@ -194,7 +194,7 @@ private:
|
||||||
|
|
||||||
// Add branch before inserted body, into body.
|
// Add branch before inserted body, into body.
|
||||||
block = block->getNextNode();
|
block = block->getNextNode();
|
||||||
create<BranchOp>(block, ValueRange());
|
create<cf::BranchOp>(block, ValueRange());
|
||||||
|
|
||||||
// Replace all gpu.yield ops with branch out of body.
|
// Replace all gpu.yield ops with branch out of body.
|
||||||
for (; block != split; block = block->getNextNode()) {
|
for (; block != split; block = block->getNextNode()) {
|
||||||
|
@ -202,7 +202,7 @@ private:
|
||||||
if (!isa<gpu::YieldOp>(terminator))
|
if (!isa<gpu::YieldOp>(terminator))
|
||||||
continue;
|
continue;
|
||||||
rewriter.setInsertionPointToEnd(block);
|
rewriter.setInsertionPointToEnd(block);
|
||||||
rewriter.replaceOpWithNewOp<BranchOp>(
|
rewriter.replaceOpWithNewOp<cf::BranchOp>(
|
||||||
terminator, split, ValueRange(terminator->getOperand(0)));
|
terminator, split, ValueRange(terminator->getOperand(0)));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -285,17 +285,17 @@ private:
|
||||||
Block *continueBlock = rewriter.splitBlock(elseBlock, elseBlock->begin());
|
Block *continueBlock = rewriter.splitBlock(elseBlock, elseBlock->begin());
|
||||||
|
|
||||||
rewriter.setInsertionPointToEnd(currentBlock);
|
rewriter.setInsertionPointToEnd(currentBlock);
|
||||||
create<CondBranchOp>(condition, thenBlock,
|
create<cf::CondBranchOp>(condition, thenBlock,
|
||||||
/*trueOperands=*/ArrayRef<Value>(), elseBlock,
|
/*trueOperands=*/ArrayRef<Value>(), elseBlock,
|
||||||
/*falseOperands=*/ArrayRef<Value>());
|
/*falseOperands=*/ArrayRef<Value>());
|
||||||
|
|
||||||
rewriter.setInsertionPointToStart(thenBlock);
|
rewriter.setInsertionPointToStart(thenBlock);
|
||||||
auto thenOperands = thenOpsFactory();
|
auto thenOperands = thenOpsFactory();
|
||||||
create<BranchOp>(continueBlock, thenOperands);
|
create<cf::BranchOp>(continueBlock, thenOperands);
|
||||||
|
|
||||||
rewriter.setInsertionPointToStart(elseBlock);
|
rewriter.setInsertionPointToStart(elseBlock);
|
||||||
auto elseOperands = elseOpsFactory();
|
auto elseOperands = elseOpsFactory();
|
||||||
create<BranchOp>(continueBlock, elseOperands);
|
create<cf::BranchOp>(continueBlock, elseOperands);
|
||||||
|
|
||||||
assert(thenOperands.size() == elseOperands.size());
|
assert(thenOperands.size() == elseOperands.size());
|
||||||
rewriter.setInsertionPointToStart(continueBlock);
|
rewriter.setInsertionPointToStart(continueBlock);
|
||||||
|
|
|
@ -12,6 +12,7 @@
|
||||||
|
|
||||||
#include "PassDetail.h"
|
#include "PassDetail.h"
|
||||||
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
|
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
|
||||||
|
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
|
||||||
#include "mlir/Dialect/DLTI/DLTI.h"
|
#include "mlir/Dialect/DLTI/DLTI.h"
|
||||||
#include "mlir/Dialect/GPU/GPUDialect.h"
|
#include "mlir/Dialect/GPU/GPUDialect.h"
|
||||||
#include "mlir/Dialect/GPU/Passes.h"
|
#include "mlir/Dialect/GPU/Passes.h"
|
||||||
|
@ -186,7 +187,7 @@ static gpu::GPUFuncOp outlineKernelFuncImpl(gpu::LaunchOp launchOp,
|
||||||
Block &launchOpEntry = launchOpBody.front();
|
Block &launchOpEntry = launchOpBody.front();
|
||||||
Block *clonedLaunchOpEntry = map.lookup(&launchOpEntry);
|
Block *clonedLaunchOpEntry = map.lookup(&launchOpEntry);
|
||||||
builder.setInsertionPointToEnd(&entryBlock);
|
builder.setInsertionPointToEnd(&entryBlock);
|
||||||
builder.create<BranchOp>(loc, clonedLaunchOpEntry);
|
builder.create<cf::BranchOp>(loc, clonedLaunchOpEntry);
|
||||||
|
|
||||||
outlinedFunc.walk([](gpu::TerminatorOp op) {
|
outlinedFunc.walk([](gpu::TerminatorOp op) {
|
||||||
OpBuilder replacer(op);
|
OpBuilder replacer(op);
|
||||||
|
|
|
@ -7,6 +7,7 @@
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
#include "PassDetail.h"
|
#include "PassDetail.h"
|
||||||
|
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
|
||||||
#include "mlir/Dialect/Linalg/IR/Linalg.h"
|
#include "mlir/Dialect/Linalg/IR/Linalg.h"
|
||||||
#include "mlir/Dialect/Linalg/Passes.h"
|
#include "mlir/Dialect/Linalg/Passes.h"
|
||||||
#include "mlir/Dialect/StandardOps/Transforms/FuncConversions.h"
|
#include "mlir/Dialect/StandardOps/Transforms/FuncConversions.h"
|
||||||
|
@ -254,13 +255,13 @@ struct LinalgDetensorize : public LinalgDetensorizeBase<LinalgDetensorize> {
|
||||||
DenseSet<BlockArgument> &blockArgsToDetensor) override {
|
DenseSet<BlockArgument> &blockArgsToDetensor) override {
|
||||||
SmallVector<Value> workList;
|
SmallVector<Value> workList;
|
||||||
|
|
||||||
func->walk([&](CondBranchOp condBr) {
|
func->walk([&](cf::CondBranchOp condBr) {
|
||||||
for (auto operand : condBr.getOperands()) {
|
for (auto operand : condBr.getOperands()) {
|
||||||
workList.push_back(operand);
|
workList.push_back(operand);
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
func->walk([&](BranchOp br) {
|
func->walk([&](cf::BranchOp br) {
|
||||||
for (auto operand : br.getOperands()) {
|
for (auto operand : br.getOperands()) {
|
||||||
workList.push_back(operand);
|
workList.push_back(operand);
|
||||||
}
|
}
|
||||||
|
|
|
@ -9,6 +9,7 @@
|
||||||
#include "mlir/Dialect/SCF/SCF.h"
|
#include "mlir/Dialect/SCF/SCF.h"
|
||||||
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
|
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
|
||||||
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
|
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
|
||||||
|
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
|
||||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||||
#include "mlir/IR/BlockAndValueMapping.h"
|
#include "mlir/IR/BlockAndValueMapping.h"
|
||||||
#include "mlir/IR/Matchers.h"
|
#include "mlir/IR/Matchers.h"
|
||||||
|
@ -165,13 +166,13 @@ struct SingleBlockExecuteInliner : public OpRewritePattern<ExecuteRegionOp> {
|
||||||
// "test.foo"() : () -> ()
|
// "test.foo"() : () -> ()
|
||||||
// %v = scf.execute_region -> i64 {
|
// %v = scf.execute_region -> i64 {
|
||||||
// %c = "test.cmp"() : () -> i1
|
// %c = "test.cmp"() : () -> i1
|
||||||
// cond_br %c, ^bb2, ^bb3
|
// cf.cond_br %c, ^bb2, ^bb3
|
||||||
// ^bb2:
|
// ^bb2:
|
||||||
// %x = "test.val1"() : () -> i64
|
// %x = "test.val1"() : () -> i64
|
||||||
// br ^bb4(%x : i64)
|
// cf.br ^bb4(%x : i64)
|
||||||
// ^bb3:
|
// ^bb3:
|
||||||
// %y = "test.val2"() : () -> i64
|
// %y = "test.val2"() : () -> i64
|
||||||
// br ^bb4(%y : i64)
|
// cf.br ^bb4(%y : i64)
|
||||||
// ^bb4(%z : i64):
|
// ^bb4(%z : i64):
|
||||||
// scf.yield %z : i64
|
// scf.yield %z : i64
|
||||||
// }
|
// }
|
||||||
|
@ -184,13 +185,13 @@ struct SingleBlockExecuteInliner : public OpRewritePattern<ExecuteRegionOp> {
|
||||||
// func @func_execute_region_elim() {
|
// func @func_execute_region_elim() {
|
||||||
// "test.foo"() : () -> ()
|
// "test.foo"() : () -> ()
|
||||||
// %c = "test.cmp"() : () -> i1
|
// %c = "test.cmp"() : () -> i1
|
||||||
// cond_br %c, ^bb1, ^bb2
|
// cf.cond_br %c, ^bb1, ^bb2
|
||||||
// ^bb1: // pred: ^bb0
|
// ^bb1: // pred: ^bb0
|
||||||
// %x = "test.val1"() : () -> i64
|
// %x = "test.val1"() : () -> i64
|
||||||
// br ^bb3(%x : i64)
|
// cf.br ^bb3(%x : i64)
|
||||||
// ^bb2: // pred: ^bb0
|
// ^bb2: // pred: ^bb0
|
||||||
// %y = "test.val2"() : () -> i64
|
// %y = "test.val2"() : () -> i64
|
||||||
// br ^bb3(%y : i64)
|
// cf.br ^bb3(%y : i64)
|
||||||
// ^bb3(%z: i64): // 2 preds: ^bb1, ^bb2
|
// ^bb3(%z: i64): // 2 preds: ^bb1, ^bb2
|
||||||
// "test.bar"(%z) : (i64) -> ()
|
// "test.bar"(%z) : (i64) -> ()
|
||||||
// return
|
// return
|
||||||
|
@ -208,13 +209,13 @@ struct MultiBlockExecuteInliner : public OpRewritePattern<ExecuteRegionOp> {
|
||||||
Block *postBlock = rewriter.splitBlock(prevBlock, op->getIterator());
|
Block *postBlock = rewriter.splitBlock(prevBlock, op->getIterator());
|
||||||
rewriter.setInsertionPointToEnd(prevBlock);
|
rewriter.setInsertionPointToEnd(prevBlock);
|
||||||
|
|
||||||
rewriter.create<BranchOp>(op.getLoc(), &op.getRegion().front());
|
rewriter.create<cf::BranchOp>(op.getLoc(), &op.getRegion().front());
|
||||||
|
|
||||||
for (Block &blk : op.getRegion()) {
|
for (Block &blk : op.getRegion()) {
|
||||||
if (YieldOp yieldOp = dyn_cast<YieldOp>(blk.getTerminator())) {
|
if (YieldOp yieldOp = dyn_cast<YieldOp>(blk.getTerminator())) {
|
||||||
rewriter.setInsertionPoint(yieldOp);
|
rewriter.setInsertionPoint(yieldOp);
|
||||||
rewriter.create<BranchOp>(yieldOp.getLoc(), postBlock,
|
rewriter.create<cf::BranchOp>(yieldOp.getLoc(), postBlock,
|
||||||
yieldOp.getResults());
|
yieldOp.getResults());
|
||||||
rewriter.eraseOp(yieldOp);
|
rewriter.eraseOp(yieldOp);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -13,7 +13,7 @@ add_mlir_dialect_library(MLIRSparseTensorPipelines
|
||||||
MLIRMemRefToLLVM
|
MLIRMemRefToLLVM
|
||||||
MLIRPass
|
MLIRPass
|
||||||
MLIRReconcileUnrealizedCasts
|
MLIRReconcileUnrealizedCasts
|
||||||
MLIRSCFToStandard
|
MLIRSCFToControlFlow
|
||||||
MLIRSparseTensor
|
MLIRSparseTensor
|
||||||
MLIRSparseTensorTransforms
|
MLIRSparseTensorTransforms
|
||||||
MLIRStandardOpsTransforms
|
MLIRStandardOpsTransforms
|
||||||
|
|
|
@ -33,7 +33,7 @@ void mlir::sparse_tensor::buildSparseCompiler(
|
||||||
pm.addNestedPass<FuncOp>(createLinalgBufferizePass());
|
pm.addNestedPass<FuncOp>(createLinalgBufferizePass());
|
||||||
pm.addNestedPass<FuncOp>(createConvertLinalgToLoopsPass());
|
pm.addNestedPass<FuncOp>(createConvertLinalgToLoopsPass());
|
||||||
pm.addNestedPass<FuncOp>(createConvertVectorToSCFPass());
|
pm.addNestedPass<FuncOp>(createConvertVectorToSCFPass());
|
||||||
pm.addPass(createLowerToCFGPass()); // --convert-scf-to-std
|
pm.addNestedPass<FuncOp>(createConvertSCFToCFPass());
|
||||||
pm.addPass(createFuncBufferizePass());
|
pm.addPass(createFuncBufferizePass());
|
||||||
pm.addPass(arith::createConstantBufferizePass());
|
pm.addPass(arith::createConstantBufferizePass());
|
||||||
pm.addNestedPass<FuncOp>(createTensorBufferizePass());
|
pm.addNestedPass<FuncOp>(createTensorBufferizePass());
|
||||||
|
|
|
@ -11,6 +11,7 @@ add_mlir_dialect_library(MLIRStandard
|
||||||
MLIRArithmetic
|
MLIRArithmetic
|
||||||
MLIRCallInterfaces
|
MLIRCallInterfaces
|
||||||
MLIRCastInterfaces
|
MLIRCastInterfaces
|
||||||
|
MLIRControlFlow
|
||||||
MLIRControlFlowInterfaces
|
MLIRControlFlowInterfaces
|
||||||
MLIRInferTypeOpInterface
|
MLIRInferTypeOpInterface
|
||||||
MLIRIR
|
MLIRIR
|
||||||
|
|
|
@ -8,9 +8,8 @@
|
||||||
|
|
||||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||||
|
|
||||||
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
|
|
||||||
#include "mlir/Dialect/Arithmetic/Utils/Utils.h"
|
|
||||||
#include "mlir/Dialect/CommonFolders.h"
|
#include "mlir/Dialect/CommonFolders.h"
|
||||||
|
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
|
||||||
#include "mlir/IR/AffineExpr.h"
|
#include "mlir/IR/AffineExpr.h"
|
||||||
#include "mlir/IR/AffineMap.h"
|
#include "mlir/IR/AffineMap.h"
|
||||||
#include "mlir/IR/BlockAndValueMapping.h"
|
#include "mlir/IR/BlockAndValueMapping.h"
|
||||||
|
@ -77,7 +76,7 @@ struct StdInlinerInterface : public DialectInlinerInterface {
|
||||||
|
|
||||||
// Replace the return with a branch to the dest.
|
// Replace the return with a branch to the dest.
|
||||||
OpBuilder builder(op);
|
OpBuilder builder(op);
|
||||||
builder.create<BranchOp>(op->getLoc(), newDest, returnOp.getOperands());
|
builder.create<cf::BranchOp>(op->getLoc(), newDest, returnOp.getOperands());
|
||||||
op->erase();
|
op->erase();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -121,130 +120,6 @@ Operation *StandardOpsDialect::materializeConstant(OpBuilder &builder,
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
|
||||||
// AssertOp
|
|
||||||
//===----------------------------------------------------------------------===//
|
|
||||||
|
|
||||||
/// Erases an assertion whose argument matches the constant one (i.e. is
/// constant true). Fails without touching the IR for any other argument.
LogicalResult AssertOp::canonicalize(AssertOp op, PatternRewriter &rewriter) {
  const bool alwaysHolds = matchPattern(op.getArg(), m_One());
  if (!alwaysHolds)
    return failure();

  rewriter.eraseOp(op);
  return success();
}
|
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
|
||||||
// BranchOp
|
|
||||||
//===----------------------------------------------------------------------===//
|
|
||||||
|
|
||||||
/// Given a successor, try to collapse it to a new destination if it only
|
|
||||||
/// contains a passthrough unconditional branch. If the successor is
|
|
||||||
/// collapsable, `successor` and `successorOperands` are updated to reference
|
|
||||||
/// the new destination and values. `argStorage` is used as storage if operands
|
|
||||||
/// to the collapsed successor need to be remapped. It must outlive uses of
|
|
||||||
/// successorOperands.
|
|
||||||
static LogicalResult collapseBranch(Block *&successor,
|
|
||||||
ValueRange &successorOperands,
|
|
||||||
SmallVectorImpl<Value> &argStorage) {
|
|
||||||
// Check that the successor only contains a unconditional branch.
|
|
||||||
if (std::next(successor->begin()) != successor->end())
|
|
||||||
return failure();
|
|
||||||
// Check that the terminator is an unconditional branch.
|
|
||||||
BranchOp successorBranch = dyn_cast<BranchOp>(successor->getTerminator());
|
|
||||||
if (!successorBranch)
|
|
||||||
return failure();
|
|
||||||
// Check that the arguments are only used within the terminator.
|
|
||||||
for (BlockArgument arg : successor->getArguments()) {
|
|
||||||
for (Operation *user : arg.getUsers())
|
|
||||||
if (user != successorBranch)
|
|
||||||
return failure();
|
|
||||||
}
|
|
||||||
// Don't try to collapse branches to infinite loops.
|
|
||||||
Block *successorDest = successorBranch.getDest();
|
|
||||||
if (successorDest == successor)
|
|
||||||
return failure();
|
|
||||||
|
|
||||||
// Update the operands to the successor. If the branch parent has no
|
|
||||||
// arguments, we can use the branch operands directly.
|
|
||||||
OperandRange operands = successorBranch.getOperands();
|
|
||||||
if (successor->args_empty()) {
|
|
||||||
successor = successorDest;
|
|
||||||
successorOperands = operands;
|
|
||||||
return success();
|
|
||||||
}
|
|
||||||
|
|
||||||
// Otherwise, we need to remap any argument operands.
|
|
||||||
for (Value operand : operands) {
|
|
||||||
BlockArgument argOperand = operand.dyn_cast<BlockArgument>();
|
|
||||||
if (argOperand && argOperand.getOwner() == successor)
|
|
||||||
argStorage.push_back(successorOperands[argOperand.getArgNumber()]);
|
|
||||||
else
|
|
||||||
argStorage.push_back(operand);
|
|
||||||
}
|
|
||||||
successor = successorDest;
|
|
||||||
successorOperands = argStorage;
|
|
||||||
return success();
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Simplify a branch to a block that has a single predecessor. This effectively
|
|
||||||
/// merges the two blocks.
|
|
||||||
static LogicalResult
|
|
||||||
simplifyBrToBlockWithSinglePred(BranchOp op, PatternRewriter &rewriter) {
|
|
||||||
// Check that the successor block has a single predecessor.
|
|
||||||
Block *succ = op.getDest();
|
|
||||||
Block *opParent = op->getBlock();
|
|
||||||
if (succ == opParent || !llvm::hasSingleElement(succ->getPredecessors()))
|
|
||||||
return failure();
|
|
||||||
|
|
||||||
// Merge the successor into the current block and erase the branch.
|
|
||||||
rewriter.mergeBlocks(succ, opParent, op.getOperands());
|
|
||||||
rewriter.eraseOp(op);
|
|
||||||
return success();
|
|
||||||
}
|
|
||||||
|
|
||||||
/// br ^bb1
|
|
||||||
/// ^bb1
|
|
||||||
/// br ^bbN(...)
|
|
||||||
///
|
|
||||||
/// -> br ^bbN(...)
|
|
||||||
///
|
|
||||||
static LogicalResult simplifyPassThroughBr(BranchOp op,
|
|
||||||
PatternRewriter &rewriter) {
|
|
||||||
Block *dest = op.getDest();
|
|
||||||
ValueRange destOperands = op.getOperands();
|
|
||||||
SmallVector<Value, 4> destOperandStorage;
|
|
||||||
|
|
||||||
// Try to collapse the successor if it points somewhere other than this
|
|
||||||
// block.
|
|
||||||
if (dest == op->getBlock() ||
|
|
||||||
failed(collapseBranch(dest, destOperands, destOperandStorage)))
|
|
||||||
return failure();
|
|
||||||
|
|
||||||
// Create a new branch with the collapsed successor.
|
|
||||||
rewriter.replaceOpWithNewOp<BranchOp>(op, dest, destOperands);
|
|
||||||
return success();
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Canonicalize a branch: first try to merge it with a single-predecessor
/// destination and, only if that fails, try to skip a pass-through block.
LogicalResult BranchOp::canonicalize(BranchOp op, PatternRewriter &rewriter) {
  if (succeeded(simplifyBrToBlockWithSinglePred(op, rewriter)))
    return success();
  return simplifyPassThroughBr(op, rewriter);
}
|
|
||||||
|
|
||||||
/// Retarget this branch by installing `block` as its successor.
void BranchOp::setDest(Block *block) {
  setSuccessor(block);
}
|
|
||||||
|
|
||||||
void BranchOp::eraseOperand(unsigned index) { (*this)->eraseOperand(index); }
|
|
||||||
|
|
||||||
/// Return the mutable operands forwarded to the successor. A branch has a
/// single successor, so `index` is asserted to be 0.
Optional<MutableOperandRange>
BranchOp::getMutableSuccessorOperands(unsigned index) {
  assert(index == 0 && "invalid successor index");
  return getDestOperandsMutable();
}
|
|
||||||
|
|
||||||
/// Return the destination block; the constant operand attributes passed in
/// are ignored.
Block *BranchOp::getSuccessorForOperands(ArrayRef<Attribute>) {
  return getDest();
}
|
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// CallOp
|
// CallOp
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
@ -307,260 +182,6 @@ LogicalResult CallIndirectOp::canonicalize(CallIndirectOp indirectCall,
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
|
||||||
// CondBranchOp
|
|
||||||
//===----------------------------------------------------------------------===//
|
|
||||||
|
|
||||||
namespace {
|
|
||||||
/// cond_br true, ^bb1, ^bb2
|
|
||||||
/// -> br ^bb1
|
|
||||||
/// cond_br false, ^bb1, ^bb2
|
|
||||||
/// -> br ^bb2
|
|
||||||
///
|
|
||||||
struct SimplifyConstCondBranchPred : public OpRewritePattern<CondBranchOp> {
|
|
||||||
using OpRewritePattern<CondBranchOp>::OpRewritePattern;
|
|
||||||
|
|
||||||
LogicalResult matchAndRewrite(CondBranchOp condbr,
|
|
||||||
PatternRewriter &rewriter) const override {
|
|
||||||
if (matchPattern(condbr.getCondition(), m_NonZero())) {
|
|
||||||
// True branch taken.
|
|
||||||
rewriter.replaceOpWithNewOp<BranchOp>(condbr, condbr.getTrueDest(),
|
|
||||||
condbr.getTrueOperands());
|
|
||||||
return success();
|
|
||||||
}
|
|
||||||
if (matchPattern(condbr.getCondition(), m_Zero())) {
|
|
||||||
// False branch taken.
|
|
||||||
rewriter.replaceOpWithNewOp<BranchOp>(condbr, condbr.getFalseDest(),
|
|
||||||
condbr.getFalseOperands());
|
|
||||||
return success();
|
|
||||||
}
|
|
||||||
return failure();
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
/// cond_br %cond, ^bb1, ^bb2
|
|
||||||
/// ^bb1
|
|
||||||
/// br ^bbN(...)
|
|
||||||
/// ^bb2
|
|
||||||
/// br ^bbK(...)
|
|
||||||
///
|
|
||||||
/// -> cond_br %cond, ^bbN(...), ^bbK(...)
|
|
||||||
///
|
|
||||||
struct SimplifyPassThroughCondBranch : public OpRewritePattern<CondBranchOp> {
|
|
||||||
using OpRewritePattern<CondBranchOp>::OpRewritePattern;
|
|
||||||
|
|
||||||
LogicalResult matchAndRewrite(CondBranchOp condbr,
|
|
||||||
PatternRewriter &rewriter) const override {
|
|
||||||
Block *trueDest = condbr.getTrueDest(), *falseDest = condbr.getFalseDest();
|
|
||||||
ValueRange trueDestOperands = condbr.getTrueOperands();
|
|
||||||
ValueRange falseDestOperands = condbr.getFalseOperands();
|
|
||||||
SmallVector<Value, 4> trueDestOperandStorage, falseDestOperandStorage;
|
|
||||||
|
|
||||||
// Try to collapse one of the current successors.
|
|
||||||
LogicalResult collapsedTrue =
|
|
||||||
collapseBranch(trueDest, trueDestOperands, trueDestOperandStorage);
|
|
||||||
LogicalResult collapsedFalse =
|
|
||||||
collapseBranch(falseDest, falseDestOperands, falseDestOperandStorage);
|
|
||||||
if (failed(collapsedTrue) && failed(collapsedFalse))
|
|
||||||
return failure();
|
|
||||||
|
|
||||||
// Create a new branch with the collapsed successors.
|
|
||||||
rewriter.replaceOpWithNewOp<CondBranchOp>(condbr, condbr.getCondition(),
|
|
||||||
trueDest, trueDestOperands,
|
|
||||||
falseDest, falseDestOperands);
|
|
||||||
return success();
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
/// cond_br %cond, ^bb1(A, ..., N), ^bb1(A, ..., N)
|
|
||||||
/// -> br ^bb1(A, ..., N)
|
|
||||||
///
|
|
||||||
/// cond_br %cond, ^bb1(A), ^bb1(B)
|
|
||||||
/// -> %select = arith.select %cond, A, B
|
|
||||||
/// br ^bb1(%select)
|
|
||||||
///
|
|
||||||
struct SimplifyCondBranchIdenticalSuccessors
|
|
||||||
: public OpRewritePattern<CondBranchOp> {
|
|
||||||
using OpRewritePattern<CondBranchOp>::OpRewritePattern;
|
|
||||||
|
|
||||||
LogicalResult matchAndRewrite(CondBranchOp condbr,
|
|
||||||
PatternRewriter &rewriter) const override {
|
|
||||||
// Check that the true and false destinations are the same and have the same
|
|
||||||
// operands.
|
|
||||||
Block *trueDest = condbr.getTrueDest();
|
|
||||||
if (trueDest != condbr.getFalseDest())
|
|
||||||
return failure();
|
|
||||||
|
|
||||||
// If all of the operands match, no selects need to be generated.
|
|
||||||
OperandRange trueOperands = condbr.getTrueOperands();
|
|
||||||
OperandRange falseOperands = condbr.getFalseOperands();
|
|
||||||
if (trueOperands == falseOperands) {
|
|
||||||
rewriter.replaceOpWithNewOp<BranchOp>(condbr, trueDest, trueOperands);
|
|
||||||
return success();
|
|
||||||
}
|
|
||||||
|
|
||||||
// Otherwise, if the current block is the only predecessor insert selects
|
|
||||||
// for any mismatched branch operands.
|
|
||||||
if (trueDest->getUniquePredecessor() != condbr->getBlock())
|
|
||||||
return failure();
|
|
||||||
|
|
||||||
// Generate a select for any operands that differ between the two.
|
|
||||||
SmallVector<Value, 8> mergedOperands;
|
|
||||||
mergedOperands.reserve(trueOperands.size());
|
|
||||||
Value condition = condbr.getCondition();
|
|
||||||
for (auto it : llvm::zip(trueOperands, falseOperands)) {
|
|
||||||
if (std::get<0>(it) == std::get<1>(it))
|
|
||||||
mergedOperands.push_back(std::get<0>(it));
|
|
||||||
else
|
|
||||||
mergedOperands.push_back(rewriter.create<arith::SelectOp>(
|
|
||||||
condbr.getLoc(), condition, std::get<0>(it), std::get<1>(it)));
|
|
||||||
}
|
|
||||||
|
|
||||||
rewriter.replaceOpWithNewOp<BranchOp>(condbr, trueDest, mergedOperands);
|
|
||||||
return success();
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
/// ...
|
|
||||||
/// cond_br %cond, ^bb1(...), ^bb2(...)
|
|
||||||
/// ...
|
|
||||||
/// ^bb1: // has single predecessor
|
|
||||||
/// ...
|
|
||||||
/// cond_br %cond, ^bb3(...), ^bb4(...)
|
|
||||||
///
|
|
||||||
/// ->
|
|
||||||
///
|
|
||||||
/// ...
|
|
||||||
/// cond_br %cond, ^bb1(...), ^bb2(...)
|
|
||||||
/// ...
|
|
||||||
/// ^bb1: // has single predecessor
|
|
||||||
/// ...
|
|
||||||
/// br ^bb3(...)
|
|
||||||
///
|
|
||||||
struct SimplifyCondBranchFromCondBranchOnSameCondition
|
|
||||||
: public OpRewritePattern<CondBranchOp> {
|
|
||||||
using OpRewritePattern<CondBranchOp>::OpRewritePattern;
|
|
||||||
|
|
||||||
LogicalResult matchAndRewrite(CondBranchOp condbr,
|
|
||||||
PatternRewriter &rewriter) const override {
|
|
||||||
// Check that we have a single distinct predecessor.
|
|
||||||
Block *currentBlock = condbr->getBlock();
|
|
||||||
Block *predecessor = currentBlock->getSinglePredecessor();
|
|
||||||
if (!predecessor)
|
|
||||||
return failure();
|
|
||||||
|
|
||||||
// Check that the predecessor terminates with a conditional branch to this
|
|
||||||
// block and that it branches on the same condition.
|
|
||||||
auto predBranch = dyn_cast<CondBranchOp>(predecessor->getTerminator());
|
|
||||||
if (!predBranch || condbr.getCondition() != predBranch.getCondition())
|
|
||||||
return failure();
|
|
||||||
|
|
||||||
// Fold this branch to an unconditional branch.
|
|
||||||
if (currentBlock == predBranch.getTrueDest())
|
|
||||||
rewriter.replaceOpWithNewOp<BranchOp>(condbr, condbr.getTrueDest(),
|
|
||||||
condbr.getTrueDestOperands());
|
|
||||||
else
|
|
||||||
rewriter.replaceOpWithNewOp<BranchOp>(condbr, condbr.getFalseDest(),
|
|
||||||
condbr.getFalseDestOperands());
|
|
||||||
return success();
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
/// cond_br %arg0, ^trueB, ^falseB
|
|
||||||
///
|
|
||||||
/// ^trueB:
|
|
||||||
/// "test.consumer1"(%arg0) : (i1) -> ()
|
|
||||||
/// ...
|
|
||||||
///
|
|
||||||
/// ^falseB:
|
|
||||||
/// "test.consumer2"(%arg0) : (i1) -> ()
|
|
||||||
/// ...
|
|
||||||
///
|
|
||||||
/// ->
|
|
||||||
///
|
|
||||||
/// cond_br %arg0, ^trueB, ^falseB
|
|
||||||
/// ^trueB:
|
|
||||||
/// "test.consumer1"(%true) : (i1) -> ()
|
|
||||||
/// ...
|
|
||||||
///
|
|
||||||
/// ^falseB:
|
|
||||||
/// "test.consumer2"(%false) : (i1) -> ()
|
|
||||||
/// ...
|
|
||||||
struct CondBranchTruthPropagation : public OpRewritePattern<CondBranchOp> {
|
|
||||||
using OpRewritePattern<CondBranchOp>::OpRewritePattern;
|
|
||||||
|
|
||||||
LogicalResult matchAndRewrite(CondBranchOp condbr,
|
|
||||||
PatternRewriter &rewriter) const override {
|
|
||||||
// Check that we have a single distinct predecessor.
|
|
||||||
bool replaced = false;
|
|
||||||
Type ty = rewriter.getI1Type();
|
|
||||||
|
|
||||||
// These variables serve to prevent creating duplicate constants
|
|
||||||
// and hold constant true or false values.
|
|
||||||
Value constantTrue = nullptr;
|
|
||||||
Value constantFalse = nullptr;
|
|
||||||
|
|
||||||
// TODO These checks can be expanded to encompas any use with only
|
|
||||||
// either the true of false edge as a predecessor. For now, we fall
|
|
||||||
// back to checking the single predecessor is given by the true/fasle
|
|
||||||
// destination, thereby ensuring that only that edge can reach the
|
|
||||||
// op.
|
|
||||||
if (condbr.getTrueDest()->getSinglePredecessor()) {
|
|
||||||
for (OpOperand &use :
|
|
||||||
llvm::make_early_inc_range(condbr.getCondition().getUses())) {
|
|
||||||
if (use.getOwner()->getBlock() == condbr.getTrueDest()) {
|
|
||||||
replaced = true;
|
|
||||||
|
|
||||||
if (!constantTrue)
|
|
||||||
constantTrue = rewriter.create<arith::ConstantOp>(
|
|
||||||
condbr.getLoc(), ty, rewriter.getBoolAttr(true));
|
|
||||||
|
|
||||||
rewriter.updateRootInPlace(use.getOwner(),
|
|
||||||
[&] { use.set(constantTrue); });
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if (condbr.getFalseDest()->getSinglePredecessor()) {
|
|
||||||
for (OpOperand &use :
|
|
||||||
llvm::make_early_inc_range(condbr.getCondition().getUses())) {
|
|
||||||
if (use.getOwner()->getBlock() == condbr.getFalseDest()) {
|
|
||||||
replaced = true;
|
|
||||||
|
|
||||||
if (!constantFalse)
|
|
||||||
constantFalse = rewriter.create<arith::ConstantOp>(
|
|
||||||
condbr.getLoc(), ty, rewriter.getBoolAttr(false));
|
|
||||||
|
|
||||||
rewriter.updateRootInPlace(use.getOwner(),
|
|
||||||
[&] { use.set(constantFalse); });
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return success(replaced);
|
|
||||||
}
|
|
||||||
};
|
|
||||||
} // namespace
|
|
||||||
|
|
||||||
/// Register every canonicalization pattern defined for conditional branches.
void CondBranchOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                               MLIRContext *context) {
  results.add<SimplifyConstCondBranchPred>(context);
  results.add<SimplifyPassThroughCondBranch>(context);
  results.add<SimplifyCondBranchIdenticalSuccessors>(context);
  results.add<SimplifyCondBranchFromCondBranchOnSameCondition>(context);
  results.add<CondBranchTruthPropagation>(context);
}
|
|
||||||
|
|
||||||
/// Return the mutable operands forwarded along the edge to successor `index`:
/// the true-destination operands for `trueIndex`, the false-destination
/// operands for any other valid index.
Optional<MutableOperandRange>
CondBranchOp::getMutableSuccessorOperands(unsigned index) {
  assert(index < getNumSuccessors() && "invalid successor index");
  if (index == trueIndex)
    return getTrueDestOperandsMutable();
  return getFalseDestOperandsMutable();
}
|
|
||||||
|
|
||||||
/// Pick the successor from the constant operands: when the first operand is
/// an integer attribute, a value of one selects the true destination and any
/// other value the false destination. Returns null otherwise.
Block *CondBranchOp::getSuccessorForOperands(ArrayRef<Attribute> operands) {
  auto condAttr = operands.front().dyn_cast_or_null<IntegerAttr>();
  if (!condAttr)
    return nullptr;
  return condAttr.getValue().isOneValue() ? getTrueDest() : getFalseDest();
}
|
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// ConstantOp
|
// ConstantOp
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
@ -621,439 +242,6 @@ LogicalResult ReturnOp::verify() {
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
|
||||||
// SwitchOp
|
|
||||||
//===----------------------------------------------------------------------===//
|
|
||||||
|
|
||||||
/// Build a switch from case values that are already packed in an attribute.
/// This only reorders the arguments for the builder it forwards to.
void SwitchOp::build(OpBuilder &builder, OperationState &result, Value value,
                     Block *defaultDestination, ValueRange defaultOperands,
                     DenseIntElementsAttr caseValues,
                     BlockRange caseDestinations,
                     ArrayRef<ValueRange> caseOperands) {
  build(builder, result, /*flag=*/value, defaultOperands, caseOperands,
        caseValues, defaultDestination, caseDestinations);
}
|
|
||||||
|
|
||||||
/// Build a switch from a plain list of case values. A non-empty list is
/// packed into a dense attribute with one element per case and the element
/// type of `value`; an empty list leaves the attribute null.
void SwitchOp::build(OpBuilder &builder, OperationState &result, Value value,
                     Block *defaultDestination, ValueRange defaultOperands,
                     ArrayRef<APInt> caseValues, BlockRange caseDestinations,
                     ArrayRef<ValueRange> caseOperands) {
  DenseIntElementsAttr packedCaseValues;
  if (!caseValues.empty()) {
    const auto numCases = static_cast<int64_t>(caseValues.size());
    ShapedType attrType = VectorType::get(numCases, value.getType());
    packedCaseValues = DenseIntElementsAttr::get(attrType, caseValues);
  }

  // Hand over to the attribute-based builder.
  build(builder, result, value, defaultDestination, defaultOperands,
        packedCaseValues, caseDestinations, caseOperands);
}
|
|
||||||
|
|
||||||
/// <cases> ::= `default` `:` bb-id (`(` ssa-use-and-type-list `)`)?
|
|
||||||
/// ( `,` integer `:` bb-id (`(` ssa-use-and-type-list `)`)? )*
|
|
||||||
/// Parse the case list of a switch: the mandatory `default` entry followed by
/// any number of comma separated `integer : successor` entries, each with an
/// optional parenthesized operand list. Case values are built with the bit
/// width of `flagType`; `caseValues` is only assigned when at least one case
/// was parsed.
static ParseResult parseSwitchOpCases(
    OpAsmParser &parser, Type &flagType, Block *&defaultDestination,
    SmallVectorImpl<OpAsmParser::OperandType> &defaultOperands,
    SmallVectorImpl<Type> &defaultOperandTypes,
    DenseIntElementsAttr &caseValues,
    SmallVectorImpl<Block *> &caseDestinations,
    SmallVectorImpl<SmallVector<OpAsmParser::OperandType>> &caseOperands,
    SmallVectorImpl<SmallVector<Type>> &caseOperandTypes) {
  // The `default` entry always comes first.
  if (parser.parseKeyword("default") || parser.parseColon() ||
      parser.parseSuccessor(defaultDestination))
    return failure();
  if (succeeded(parser.parseOptionalLParen())) {
    if (parser.parseRegionArgumentList(defaultOperands) ||
        parser.parseColonTypeList(defaultOperandTypes) || parser.parseRParen())
      return failure();
  }

  const unsigned bitWidth = flagType.getIntOrFloatBitWidth();
  SmallVector<APInt> parsedValues;

  // Every further entry is introduced by a comma.
  while (succeeded(parser.parseOptionalComma())) {
    int64_t rawValue = 0;
    if (failed(parser.parseInteger(rawValue)))
      return failure();
    parsedValues.push_back(APInt(bitWidth, rawValue));

    Block *caseDest = nullptr;
    if (failed(parser.parseColon()) || failed(parser.parseSuccessor(caseDest)))
      return failure();

    SmallVector<OpAsmParser::OperandType> caseArgs;
    SmallVector<Type> caseArgTypes;
    if (succeeded(parser.parseOptionalLParen())) {
      if (failed(parser.parseRegionArgumentList(caseArgs)) ||
          failed(parser.parseColonTypeList(caseArgTypes)) ||
          failed(parser.parseRParen()))
        return failure();
    }

    caseDestinations.push_back(caseDest);
    caseOperands.emplace_back(caseArgs);
    caseOperandTypes.emplace_back(caseArgTypes);
  }

  // Without any case there is no attribute to create.
  if (parsedValues.empty())
    return success();

  ShapedType caseValueType =
      VectorType::get(static_cast<int64_t>(parsedValues.size()), flagType);
  caseValues = DenseIntElementsAttr::get(caseValueType, parsedValues);
  return success();
}
|
|
||||||
|
|
||||||
/// Print the case list of a switch: the `default` entry first and then, when
/// case values are present, one entry per line for each case. Only the
/// printer, the destinations, the operands and the case values are used; the
/// remaining parameters are not read.
static void printSwitchOpCases(
    OpAsmPrinter &p, SwitchOp op, Type flagType, Block *defaultDestination,
    OperandRange defaultOperands, TypeRange defaultOperandTypes,
    DenseIntElementsAttr caseValues, SuccessorRange caseDestinations,
    OperandRangeRange caseOperands, const TypeRangeRange &caseOperandTypes) {
  p << " default: ";
  p.printSuccessorAndUseList(defaultDestination, defaultOperands);

  // A null attribute means there are no cases to print.
  if (!caseValues)
    return;

  for (const auto &entry : llvm::enumerate(caseValues.getValues<APInt>())) {
    const auto caseIdx = entry.index();
    p << ',';
    p.printNewline();
    p << " " << entry.value().getLimitedValue() << ": ";
    p.printSuccessorAndUseList(caseDestinations[caseIdx],
                               caseOperands[caseIdx]);
  }
  p.printNewline();
}
|
|
||||||
|
|
||||||
/// Verify the case list of a switch: the element type of the case values has
/// to equal the type of the flag, and there has to be exactly one case value
/// per case destination.
LogicalResult SwitchOp::verify() {
  auto caseValues = getCaseValues();
  auto caseDestinations = getCaseDestinations();

  // Handle a missing case-value attribute before anything dereferences it.
  // Destinations without values are a count mismatch, not a valid switch.
  if (!caseValues) {
    if (caseDestinations.empty())
      return success();
    return emitOpError() << "number of case values (" << 0
                         << ") should match number of "
                            "case destinations ("
                         << caseDestinations.size() << ")";
  }

  Type flagType = getFlag().getType();
  Type caseValueType = caseValues->getType().getElementType();
  if (caseValueType != flagType)
    return emitOpError() << "'flag' type (" << flagType
                         << ") should match case value type (" << caseValueType
                         << ")";

  if (caseValues->size() != static_cast<int64_t>(caseDestinations.size()))
    return emitOpError() << "number of case values (" << caseValues->size()
                         << ") should match number of "
                            "case destinations ("
                         << caseDestinations.size() << ")";
  return success();
}
|
|
||||||
|
|
||||||
/// Return the mutable operands forwarded to successor `index`. Successor 0 is
/// the default destination; successor `i > 0` is case `i - 1`.
Optional<MutableOperandRange>
SwitchOp::getMutableSuccessorOperands(unsigned index) {
  assert(index < getNumSuccessors() && "invalid successor index");
  if (index == 0)
    return getDefaultOperandsMutable();
  return getCaseOperandsMutable(index - 1);
}
|
|
||||||
|
|
||||||
/// Pick the successor from the constant operands. A switch without case
/// values always goes to the default destination. Otherwise, when the first
/// operand is an integer attribute, the destination of the first case with
/// that value is returned, or the default destination if no case matches.
/// Returns null when the first operand is not an integer attribute.
Block *SwitchOp::getSuccessorForOperands(ArrayRef<Attribute> operands) {
  Optional<DenseIntElementsAttr> caseValues = getCaseValues();
  if (!caseValues)
    return getDefaultDestination();

  auto flagAttr = operands.front().dyn_cast_or_null<IntegerAttr>();
  if (!flagAttr)
    return nullptr;

  SuccessorRange caseDests = getCaseDestinations();
  for (const auto &entry : llvm::enumerate(caseValues->getValues<APInt>())) {
    if (entry.value() == flagAttr.getValue())
      return caseDests[entry.index()];
  }
  return getDefaultDestination();
}
|
|
||||||
|
|
||||||
/// switch %flag : i32, [
|
|
||||||
/// default: ^bb1
|
|
||||||
/// ]
|
|
||||||
/// -> br ^bb1
|
|
||||||
static LogicalResult simplifySwitchWithOnlyDefault(SwitchOp op,
|
|
||||||
PatternRewriter &rewriter) {
|
|
||||||
if (!op.getCaseDestinations().empty())
|
|
||||||
return failure();
|
|
||||||
|
|
||||||
rewriter.replaceOpWithNewOp<BranchOp>(op, op.getDefaultDestination(),
|
|
||||||
op.getDefaultOperands());
|
|
||||||
return success();
|
|
||||||
}
|
|
||||||
|
|
||||||
/// switch %flag : i32, [
|
|
||||||
/// default: ^bb1,
|
|
||||||
/// 42: ^bb1,
|
|
||||||
/// 43: ^bb2
|
|
||||||
/// ]
|
|
||||||
/// ->
|
|
||||||
/// switch %flag : i32, [
|
|
||||||
/// default: ^bb1,
|
|
||||||
/// 43: ^bb2
|
|
||||||
/// ]
|
|
||||||
static LogicalResult
|
|
||||||
dropSwitchCasesThatMatchDefault(SwitchOp op, PatternRewriter &rewriter) {
|
|
||||||
SmallVector<Block *> newCaseDestinations;
|
|
||||||
SmallVector<ValueRange> newCaseOperands;
|
|
||||||
SmallVector<APInt> newCaseValues;
|
|
||||||
bool requiresChange = false;
|
|
||||||
auto caseValues = op.getCaseValues();
|
|
||||||
auto caseDests = op.getCaseDestinations();
|
|
||||||
|
|
||||||
for (const auto &it : llvm::enumerate(caseValues->getValues<APInt>())) {
|
|
||||||
if (caseDests[it.index()] == op.getDefaultDestination() &&
|
|
||||||
op.getCaseOperands(it.index()) == op.getDefaultOperands()) {
|
|
||||||
requiresChange = true;
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
newCaseDestinations.push_back(caseDests[it.index()]);
|
|
||||||
newCaseOperands.push_back(op.getCaseOperands(it.index()));
|
|
||||||
newCaseValues.push_back(it.value());
|
|
||||||
}
|
|
||||||
|
|
||||||
if (!requiresChange)
|
|
||||||
return failure();
|
|
||||||
|
|
||||||
rewriter.replaceOpWithNewOp<SwitchOp>(
|
|
||||||
op, op.getFlag(), op.getDefaultDestination(), op.getDefaultOperands(),
|
|
||||||
newCaseValues, newCaseDestinations, newCaseOperands);
|
|
||||||
return success();
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Helper for folding a switch with a constant value.
|
|
||||||
/// switch %c_42 : i32, [
|
|
||||||
/// default: ^bb1 ,
|
|
||||||
/// 42: ^bb2,
|
|
||||||
/// 43: ^bb3
|
|
||||||
/// ]
|
|
||||||
/// -> br ^bb2
|
|
||||||
static void foldSwitch(SwitchOp op, PatternRewriter &rewriter,
|
|
||||||
const APInt &caseValue) {
|
|
||||||
auto caseValues = op.getCaseValues();
|
|
||||||
for (const auto &it : llvm::enumerate(caseValues->getValues<APInt>())) {
|
|
||||||
if (it.value() == caseValue) {
|
|
||||||
rewriter.replaceOpWithNewOp<BranchOp>(
|
|
||||||
op, op.getCaseDestinations()[it.index()],
|
|
||||||
op.getCaseOperands(it.index()));
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
rewriter.replaceOpWithNewOp<BranchOp>(op, op.getDefaultDestination(),
|
|
||||||
op.getDefaultOperands());
|
|
||||||
}
|
|
||||||
|
|
||||||
/// switch %c_42 : i32, [
|
|
||||||
/// default: ^bb1,
|
|
||||||
/// 42: ^bb2,
|
|
||||||
/// 43: ^bb3
|
|
||||||
/// ]
|
|
||||||
/// -> br ^bb2
|
|
||||||
static LogicalResult simplifyConstSwitchValue(SwitchOp op,
|
|
||||||
PatternRewriter &rewriter) {
|
|
||||||
APInt caseValue;
|
|
||||||
if (!matchPattern(op.getFlag(), m_ConstantInt(&caseValue)))
|
|
||||||
return failure();
|
|
||||||
|
|
||||||
foldSwitch(op, rewriter, caseValue);
|
|
||||||
return success();
|
|
||||||
}
|
|
||||||
|
|
||||||
/// switch %c_42 : i32, [
|
|
||||||
/// default: ^bb1,
|
|
||||||
/// 42: ^bb2,
|
|
||||||
/// ]
|
|
||||||
/// ^bb2:
|
|
||||||
/// br ^bb3
|
|
||||||
/// ->
|
|
||||||
/// switch %c_42 : i32, [
|
|
||||||
/// default: ^bb1,
|
|
||||||
/// 42: ^bb3,
|
|
||||||
/// ]
|
|
||||||
static LogicalResult simplifyPassThroughSwitch(SwitchOp op,
|
|
||||||
PatternRewriter &rewriter) {
|
|
||||||
SmallVector<Block *> newCaseDests;
|
|
||||||
SmallVector<ValueRange> newCaseOperands;
|
|
||||||
SmallVector<SmallVector<Value>> argStorage;
|
|
||||||
auto caseValues = op.getCaseValues();
|
|
||||||
auto caseDests = op.getCaseDestinations();
|
|
||||||
bool requiresChange = false;
|
|
||||||
for (int64_t i = 0, size = caseValues->size(); i < size; ++i) {
|
|
||||||
Block *caseDest = caseDests[i];
|
|
||||||
ValueRange caseOperands = op.getCaseOperands(i);
|
|
||||||
argStorage.emplace_back();
|
|
||||||
if (succeeded(collapseBranch(caseDest, caseOperands, argStorage.back())))
|
|
||||||
requiresChange = true;
|
|
||||||
|
|
||||||
newCaseDests.push_back(caseDest);
|
|
||||||
newCaseOperands.push_back(caseOperands);
|
|
||||||
}
|
|
||||||
|
|
||||||
Block *defaultDest = op.getDefaultDestination();
|
|
||||||
ValueRange defaultOperands = op.getDefaultOperands();
|
|
||||||
argStorage.emplace_back();
|
|
||||||
|
|
||||||
if (succeeded(
|
|
||||||
collapseBranch(defaultDest, defaultOperands, argStorage.back())))
|
|
||||||
requiresChange = true;
|
|
||||||
|
|
||||||
if (!requiresChange)
|
|
||||||
return failure();
|
|
||||||
|
|
||||||
rewriter.replaceOpWithNewOp<SwitchOp>(op, op.getFlag(), defaultDest,
|
|
||||||
defaultOperands, caseValues.getValue(),
|
|
||||||
newCaseDests, newCaseOperands);
|
|
||||||
return success();
|
|
||||||
}
|
|
||||||
|
|
||||||
/// switch %flag : i32, [
|
|
||||||
/// default: ^bb1,
|
|
||||||
/// 42: ^bb2,
|
|
||||||
/// ]
|
|
||||||
/// ^bb2:
|
|
||||||
/// switch %flag : i32, [
|
|
||||||
/// default: ^bb3,
|
|
||||||
/// 42: ^bb4
|
|
||||||
/// ]
|
|
||||||
/// ->
|
|
||||||
/// switch %flag : i32, [
|
|
||||||
/// default: ^bb1,
|
|
||||||
/// 42: ^bb2,
|
|
||||||
/// ]
|
|
||||||
/// ^bb2:
|
|
||||||
/// br ^bb4
|
|
||||||
///
|
|
||||||
/// and
|
|
||||||
///
|
|
||||||
/// switch %flag : i32, [
|
|
||||||
/// default: ^bb1,
|
|
||||||
/// 42: ^bb2,
|
|
||||||
/// ]
|
|
||||||
/// ^bb2:
|
|
||||||
/// switch %flag : i32, [
|
|
||||||
/// default: ^bb3,
|
|
||||||
/// 43: ^bb4
|
|
||||||
/// ]
|
|
||||||
/// ->
|
|
||||||
/// switch %flag : i32, [
|
|
||||||
/// default: ^bb1,
|
|
||||||
/// 42: ^bb2,
|
|
||||||
/// ]
|
|
||||||
/// ^bb2:
|
|
||||||
/// br ^bb3
|
|
||||||
static LogicalResult
|
|
||||||
simplifySwitchFromSwitchOnSameCondition(SwitchOp op,
|
|
||||||
PatternRewriter &rewriter) {
|
|
||||||
// Check that we have a single distinct predecessor.
|
|
||||||
Block *currentBlock = op->getBlock();
|
|
||||||
Block *predecessor = currentBlock->getSinglePredecessor();
|
|
||||||
if (!predecessor)
|
|
||||||
return failure();
|
|
||||||
|
|
||||||
// Check that the predecessor terminates with a switch branch to this block
|
|
||||||
// and that it branches on the same condition and that this branch isn't the
|
|
||||||
// default destination.
|
|
||||||
auto predSwitch = dyn_cast<SwitchOp>(predecessor->getTerminator());
|
|
||||||
if (!predSwitch || op.getFlag() != predSwitch.getFlag() ||
|
|
||||||
predSwitch.getDefaultDestination() == currentBlock)
|
|
||||||
return failure();
|
|
||||||
|
|
||||||
// Fold this switch to an unconditional branch.
|
|
||||||
SuccessorRange predDests = predSwitch.getCaseDestinations();
|
|
||||||
auto it = llvm::find(predDests, currentBlock);
|
|
||||||
if (it != predDests.end()) {
|
|
||||||
Optional<DenseIntElementsAttr> predCaseValues = predSwitch.getCaseValues();
|
|
||||||
foldSwitch(op, rewriter,
|
|
||||||
predCaseValues->getValues<APInt>()[it - predDests.begin()]);
|
|
||||||
} else {
|
|
||||||
rewriter.replaceOpWithNewOp<BranchOp>(op, op.getDefaultDestination(),
|
|
||||||
op.getDefaultOperands());
|
|
||||||
}
|
|
||||||
return success();
|
|
||||||
}
|
|
||||||
|
|
||||||
/// switch %flag : i32, [
|
|
||||||
/// default: ^bb1,
|
|
||||||
/// 42: ^bb2
|
|
||||||
/// ]
|
|
||||||
/// ^bb1:
|
|
||||||
/// switch %flag : i32, [
|
|
||||||
/// default: ^bb3,
|
|
||||||
/// 42: ^bb4,
|
|
||||||
/// 43: ^bb5
|
|
||||||
/// ]
|
|
||||||
/// ->
|
|
||||||
/// switch %flag : i32, [
|
|
||||||
/// default: ^bb1,
|
|
||||||
/// 42: ^bb2,
|
|
||||||
/// ]
|
|
||||||
/// ^bb1:
|
|
||||||
/// switch %flag : i32, [
|
|
||||||
/// default: ^bb3,
|
|
||||||
/// 43: ^bb5
|
|
||||||
/// ]
|
|
||||||
static LogicalResult
|
|
||||||
simplifySwitchFromDefaultSwitchOnSameCondition(SwitchOp op,
|
|
||||||
PatternRewriter &rewriter) {
|
|
||||||
// Check that we have a single distinct predecessor.
|
|
||||||
Block *currentBlock = op->getBlock();
|
|
||||||
Block *predecessor = currentBlock->getSinglePredecessor();
|
|
||||||
if (!predecessor)
|
|
||||||
return failure();
|
|
||||||
|
|
||||||
// Check that the predecessor terminates with a switch branch to this block
|
|
||||||
// and that it branches on the same condition and that this branch is the
|
|
||||||
// default destination.
|
|
||||||
auto predSwitch = dyn_cast<SwitchOp>(predecessor->getTerminator());
|
|
||||||
if (!predSwitch || op.getFlag() != predSwitch.getFlag() ||
|
|
||||||
predSwitch.getDefaultDestination() != currentBlock)
|
|
||||||
return failure();
|
|
||||||
|
|
||||||
// Delete case values that are not possible here.
|
|
||||||
DenseSet<APInt> caseValuesToRemove;
|
|
||||||
auto predDests = predSwitch.getCaseDestinations();
|
|
||||||
auto predCaseValues = predSwitch.getCaseValues();
|
|
||||||
for (int64_t i = 0, size = predCaseValues->size(); i < size; ++i)
|
|
||||||
if (currentBlock != predDests[i])
|
|
||||||
caseValuesToRemove.insert(predCaseValues->getValues<APInt>()[i]);
|
|
||||||
|
|
||||||
SmallVector<Block *> newCaseDestinations;
|
|
||||||
SmallVector<ValueRange> newCaseOperands;
|
|
||||||
SmallVector<APInt> newCaseValues;
|
|
||||||
bool requiresChange = false;
|
|
||||||
|
|
||||||
auto caseValues = op.getCaseValues();
|
|
||||||
auto caseDests = op.getCaseDestinations();
|
|
||||||
for (const auto &it : llvm::enumerate(caseValues->getValues<APInt>())) {
|
|
||||||
if (caseValuesToRemove.contains(it.value())) {
|
|
||||||
requiresChange = true;
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
newCaseDestinations.push_back(caseDests[it.index()]);
|
|
||||||
newCaseOperands.push_back(op.getCaseOperands(it.index()));
|
|
||||||
newCaseValues.push_back(it.value());
|
|
||||||
}
|
|
||||||
|
|
||||||
if (!requiresChange)
|
|
||||||
return failure();
|
|
||||||
|
|
||||||
rewriter.replaceOpWithNewOp<SwitchOp>(
|
|
||||||
op, op.getFlag(), op.getDefaultDestination(), op.getDefaultOperands(),
|
|
||||||
newCaseValues, newCaseDestinations, newCaseOperands);
|
|
||||||
return success();
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Adds the SwitchOp canonicalization patterns to `results`. Each pattern is
/// a free function registered by address; they are added in the order listed
/// below. `context` is not used.
void SwitchOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                           MLIRContext *context) {
  results.add(&simplifySwitchWithOnlyDefault);
  results.add(&dropSwitchCasesThatMatchDefault);
  results.add(&simplifyConstSwitchValue);
  results.add(&simplifyPassThroughSwitch);
  results.add(&simplifySwitchFromSwitchOnSameCondition);
  results.add(&simplifySwitchFromDefaultSwitchOnSameCondition);
}
|
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// TableGen'd op method definitions
|
// TableGen'd op method definitions
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
|
@ -7,6 +7,7 @@
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
|
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
|
||||||
|
#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
|
||||||
#include "mlir/Dialect/EmitC/IR/EmitC.h"
|
#include "mlir/Dialect/EmitC/IR/EmitC.h"
|
||||||
#include "mlir/Dialect/Math/IR/Math.h"
|
#include "mlir/Dialect/Math/IR/Math.h"
|
||||||
#include "mlir/Dialect/SCF/SCF.h"
|
#include "mlir/Dialect/SCF/SCF.h"
|
||||||
|
@ -41,6 +42,7 @@ void registerToCppTranslation() {
|
||||||
[](DialectRegistry ®istry) {
|
[](DialectRegistry ®istry) {
|
||||||
// clang-format off
|
// clang-format off
|
||||||
registry.insert<arith::ArithmeticDialect,
|
registry.insert<arith::ArithmeticDialect,
|
||||||
|
cf::ControlFlowDialect,
|
||||||
emitc::EmitCDialect,
|
emitc::EmitCDialect,
|
||||||
math::MathDialect,
|
math::MathDialect,
|
||||||
StandardOpsDialect,
|
StandardOpsDialect,
|
||||||
|
|
|
@ -6,8 +6,7 @@
|
||||||
//
|
//
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
#include <utility>
|
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
|
||||||
|
|
||||||
#include "mlir/Dialect/EmitC/IR/EmitC.h"
|
#include "mlir/Dialect/EmitC/IR/EmitC.h"
|
||||||
#include "mlir/Dialect/SCF/SCF.h"
|
#include "mlir/Dialect/SCF/SCF.h"
|
||||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||||
|
@ -23,6 +22,7 @@
|
||||||
#include "llvm/ADT/TypeSwitch.h"
|
#include "llvm/ADT/TypeSwitch.h"
|
||||||
#include "llvm/Support/Debug.h"
|
#include "llvm/Support/Debug.h"
|
||||||
#include "llvm/Support/FormatVariadic.h"
|
#include "llvm/Support/FormatVariadic.h"
|
||||||
|
#include <utility>
|
||||||
|
|
||||||
#define DEBUG_TYPE "translate-to-cpp"
|
#define DEBUG_TYPE "translate-to-cpp"
|
||||||
|
|
||||||
|
@ -237,7 +237,8 @@ static LogicalResult printOperation(CppEmitter &emitter,
|
||||||
return printConstantOp(emitter, operation, value);
|
return printConstantOp(emitter, operation, value);
|
||||||
}
|
}
|
||||||
|
|
||||||
static LogicalResult printOperation(CppEmitter &emitter, BranchOp branchOp) {
|
static LogicalResult printOperation(CppEmitter &emitter,
|
||||||
|
cf::BranchOp branchOp) {
|
||||||
raw_ostream &os = emitter.ostream();
|
raw_ostream &os = emitter.ostream();
|
||||||
Block &successor = *branchOp.getSuccessor();
|
Block &successor = *branchOp.getSuccessor();
|
||||||
|
|
||||||
|
@ -257,7 +258,7 @@ static LogicalResult printOperation(CppEmitter &emitter, BranchOp branchOp) {
|
||||||
}
|
}
|
||||||
|
|
||||||
static LogicalResult printOperation(CppEmitter &emitter,
|
static LogicalResult printOperation(CppEmitter &emitter,
|
||||||
CondBranchOp condBranchOp) {
|
cf::CondBranchOp condBranchOp) {
|
||||||
raw_indented_ostream &os = emitter.ostream();
|
raw_indented_ostream &os = emitter.ostream();
|
||||||
Block &trueSuccessor = *condBranchOp.getTrueDest();
|
Block &trueSuccessor = *condBranchOp.getTrueDest();
|
||||||
Block &falseSuccessor = *condBranchOp.getFalseDest();
|
Block &falseSuccessor = *condBranchOp.getFalseDest();
|
||||||
|
@ -637,11 +638,12 @@ static LogicalResult printOperation(CppEmitter &emitter, FuncOp functionOp) {
|
||||||
return failure();
|
return failure();
|
||||||
}
|
}
|
||||||
for (Operation &op : block.getOperations()) {
|
for (Operation &op : block.getOperations()) {
|
||||||
// When generating code for an scf.if or std.cond_br op no semicolon needs
|
// When generating code for an scf.if or cf.cond_br op no semicolon needs
|
||||||
// to be printed after the closing brace.
|
// to be printed after the closing brace.
|
||||||
// When generating code for an scf.for op, printing a trailing semicolon
|
// When generating code for an scf.for op, printing a trailing semicolon
|
||||||
// is handled within the printOperation function.
|
// is handled within the printOperation function.
|
||||||
bool trailingSemicolon = !isa<scf::IfOp, scf::ForOp, CondBranchOp>(op);
|
bool trailingSemicolon =
|
||||||
|
!isa<scf::IfOp, scf::ForOp, cf::CondBranchOp>(op);
|
||||||
|
|
||||||
if (failed(emitter.emitOperation(
|
if (failed(emitter.emitOperation(
|
||||||
op, /*trailingSemicolon=*/trailingSemicolon)))
|
op, /*trailingSemicolon=*/trailingSemicolon)))
|
||||||
|
@ -907,8 +909,8 @@ LogicalResult CppEmitter::emitOperation(Operation &op, bool trailingSemicolon) {
|
||||||
.Case<scf::ForOp, scf::IfOp, scf::YieldOp>(
|
.Case<scf::ForOp, scf::IfOp, scf::YieldOp>(
|
||||||
[&](auto op) { return printOperation(*this, op); })
|
[&](auto op) { return printOperation(*this, op); })
|
||||||
// Standard ops.
|
// Standard ops.
|
||||||
.Case<BranchOp, mlir::CallOp, CondBranchOp, mlir::ConstantOp, FuncOp,
|
.Case<cf::BranchOp, mlir::CallOp, cf::CondBranchOp, mlir::ConstantOp,
|
||||||
ModuleOp, ReturnOp>(
|
FuncOp, ModuleOp, ReturnOp>(
|
||||||
[&](auto op) { return printOperation(*this, op); })
|
[&](auto op) { return printOperation(*this, op); })
|
||||||
// Arithmetic ops.
|
// Arithmetic ops.
|
||||||
.Case<arith::ConstantOp>(
|
.Case<arith::ConstantOp>(
|
||||||
|
|
|
@ -52,10 +52,10 @@ func @control_flow(%arg: memref<2xf32>, %cond: i1) attributes {test.ptr = "func"
|
||||||
%1 = memref.alloca() {test.ptr = "alloca_2"} : memref<8x64xf32>
|
%1 = memref.alloca() {test.ptr = "alloca_2"} : memref<8x64xf32>
|
||||||
%2 = memref.alloc() {test.ptr = "alloc_1"} : memref<8x64xf32>
|
%2 = memref.alloc() {test.ptr = "alloc_1"} : memref<8x64xf32>
|
||||||
|
|
||||||
cond_br %cond, ^bb1(%0 : memref<8x64xf32>), ^bb2(%0 : memref<8x64xf32>)
|
cf.cond_br %cond, ^bb1(%0 : memref<8x64xf32>), ^bb2(%0 : memref<8x64xf32>)
|
||||||
|
|
||||||
^bb1(%arg1: memref<8x64xf32>):
|
^bb1(%arg1: memref<8x64xf32>):
|
||||||
br ^bb2(%arg1 : memref<8x64xf32>)
|
cf.br ^bb2(%arg1 : memref<8x64xf32>)
|
||||||
|
|
||||||
^bb2(%arg2: memref<8x64xf32>):
|
^bb2(%arg2: memref<8x64xf32>):
|
||||||
return
|
return
|
||||||
|
@ -85,10 +85,10 @@ func @control_flow_merge(%arg: memref<2xf32>, %cond: i1) attributes {test.ptr =
|
||||||
%1 = memref.alloca() {test.ptr = "alloca_2"} : memref<8x64xf32>
|
%1 = memref.alloca() {test.ptr = "alloca_2"} : memref<8x64xf32>
|
||||||
%2 = memref.alloc() {test.ptr = "alloc_1"} : memref<8x64xf32>
|
%2 = memref.alloc() {test.ptr = "alloc_1"} : memref<8x64xf32>
|
||||||
|
|
||||||
cond_br %cond, ^bb1(%0 : memref<8x64xf32>), ^bb2(%2 : memref<8x64xf32>)
|
cf.cond_br %cond, ^bb1(%0 : memref<8x64xf32>), ^bb2(%2 : memref<8x64xf32>)
|
||||||
|
|
||||||
^bb1(%arg1: memref<8x64xf32>):
|
^bb1(%arg1: memref<8x64xf32>):
|
||||||
br ^bb2(%arg1 : memref<8x64xf32>)
|
cf.br ^bb2(%arg1 : memref<8x64xf32>)
|
||||||
|
|
||||||
^bb2(%arg2: memref<8x64xf32>):
|
^bb2(%arg2: memref<8x64xf32>):
|
||||||
return
|
return
|
||||||
|
|
|
@ -2,11 +2,11 @@
|
||||||
|
|
||||||
// CHECK-LABEL: Testing : func_condBranch
|
// CHECK-LABEL: Testing : func_condBranch
|
||||||
func @func_condBranch(%cond : i1) {
|
func @func_condBranch(%cond : i1) {
|
||||||
cond_br %cond, ^bb1, ^bb2
|
cf.cond_br %cond, ^bb1, ^bb2
|
||||||
^bb1:
|
^bb1:
|
||||||
br ^exit
|
cf.br ^exit
|
||||||
^bb2:
|
^bb2:
|
||||||
br ^exit
|
cf.br ^exit
|
||||||
^exit:
|
^exit:
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -49,14 +49,14 @@ func @func_condBranch(%cond : i1) {
|
||||||
|
|
||||||
// CHECK-LABEL: Testing : func_loop
|
// CHECK-LABEL: Testing : func_loop
|
||||||
func @func_loop(%arg0 : i32, %arg1 : i32) {
|
func @func_loop(%arg0 : i32, %arg1 : i32) {
|
||||||
br ^loopHeader(%arg0 : i32)
|
cf.br ^loopHeader(%arg0 : i32)
|
||||||
^loopHeader(%counter : i32):
|
^loopHeader(%counter : i32):
|
||||||
%lessThan = arith.cmpi slt, %counter, %arg1 : i32
|
%lessThan = arith.cmpi slt, %counter, %arg1 : i32
|
||||||
cond_br %lessThan, ^loopBody, ^exit
|
cf.cond_br %lessThan, ^loopBody, ^exit
|
||||||
^loopBody:
|
^loopBody:
|
||||||
%const0 = arith.constant 1 : i32
|
%const0 = arith.constant 1 : i32
|
||||||
%inc = arith.addi %counter, %const0 : i32
|
%inc = arith.addi %counter, %const0 : i32
|
||||||
br ^loopHeader(%inc : i32)
|
cf.br ^loopHeader(%inc : i32)
|
||||||
^exit:
|
^exit:
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -153,17 +153,17 @@ func @func_loop_nested_region(
|
||||||
%arg2 : index,
|
%arg2 : index,
|
||||||
%arg3 : index,
|
%arg3 : index,
|
||||||
%arg4 : index) {
|
%arg4 : index) {
|
||||||
br ^loopHeader(%arg0 : i32)
|
cf.br ^loopHeader(%arg0 : i32)
|
||||||
^loopHeader(%counter : i32):
|
^loopHeader(%counter : i32):
|
||||||
%lessThan = arith.cmpi slt, %counter, %arg1 : i32
|
%lessThan = arith.cmpi slt, %counter, %arg1 : i32
|
||||||
cond_br %lessThan, ^loopBody, ^exit
|
cf.cond_br %lessThan, ^loopBody, ^exit
|
||||||
^loopBody:
|
^loopBody:
|
||||||
%const0 = arith.constant 1 : i32
|
%const0 = arith.constant 1 : i32
|
||||||
%inc = arith.addi %counter, %const0 : i32
|
%inc = arith.addi %counter, %const0 : i32
|
||||||
scf.for %arg5 = %arg2 to %arg3 step %arg4 {
|
scf.for %arg5 = %arg2 to %arg3 step %arg4 {
|
||||||
scf.for %arg6 = %arg2 to %arg3 step %arg4 { }
|
scf.for %arg6 = %arg2 to %arg3 step %arg4 { }
|
||||||
}
|
}
|
||||||
br ^loopHeader(%inc : i32)
|
cf.br ^loopHeader(%inc : i32)
|
||||||
^exit:
|
^exit:
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
|
@ -19,7 +19,7 @@ func @func_simpleBranch(%arg0: i32, %arg1 : i32) -> i32 {
|
||||||
// CHECK-NEXT: LiveOut: arg0@0 arg1@0
|
// CHECK-NEXT: LiveOut: arg0@0 arg1@0
|
||||||
// CHECK-NEXT: BeginLiveness
|
// CHECK-NEXT: BeginLiveness
|
||||||
// CHECK-NEXT: EndLiveness
|
// CHECK-NEXT: EndLiveness
|
||||||
br ^exit
|
cf.br ^exit
|
||||||
^exit:
|
^exit:
|
||||||
// CHECK: Block: 1
|
// CHECK: Block: 1
|
||||||
// CHECK-NEXT: LiveIn: arg0@0 arg1@0
|
// CHECK-NEXT: LiveIn: arg0@0 arg1@0
|
||||||
|
@ -42,17 +42,17 @@ func @func_condBranch(%cond : i1, %arg1: i32, %arg2 : i32) -> i32 {
|
||||||
// CHECK-NEXT: LiveOut: arg1@0 arg2@0
|
// CHECK-NEXT: LiveOut: arg1@0 arg2@0
|
||||||
// CHECK-NEXT: BeginLiveness
|
// CHECK-NEXT: BeginLiveness
|
||||||
// CHECK-NEXT: EndLiveness
|
// CHECK-NEXT: EndLiveness
|
||||||
cond_br %cond, ^bb1, ^bb2
|
cf.cond_br %cond, ^bb1, ^bb2
|
||||||
^bb1:
|
^bb1:
|
||||||
// CHECK: Block: 1
|
// CHECK: Block: 1
|
||||||
// CHECK-NEXT: LiveIn: arg1@0 arg2@0
|
// CHECK-NEXT: LiveIn: arg1@0 arg2@0
|
||||||
// CHECK-NEXT: LiveOut: arg1@0 arg2@0
|
// CHECK-NEXT: LiveOut: arg1@0 arg2@0
|
||||||
br ^exit
|
cf.br ^exit
|
||||||
^bb2:
|
^bb2:
|
||||||
// CHECK: Block: 2
|
// CHECK: Block: 2
|
||||||
// CHECK-NEXT: LiveIn: arg1@0 arg2@0
|
// CHECK-NEXT: LiveIn: arg1@0 arg2@0
|
||||||
// CHECK-NEXT: LiveOut: arg1@0 arg2@0
|
// CHECK-NEXT: LiveOut: arg1@0 arg2@0
|
||||||
br ^exit
|
cf.br ^exit
|
||||||
^exit:
|
^exit:
|
||||||
// CHECK: Block: 3
|
// CHECK: Block: 3
|
||||||
// CHECK-NEXT: LiveIn: arg1@0 arg2@0
|
// CHECK-NEXT: LiveIn: arg1@0 arg2@0
|
||||||
|
@ -74,7 +74,7 @@ func @func_loop(%arg0 : i32, %arg1 : i32) -> i32 {
|
||||||
// CHECK-NEXT: LiveIn:{{ *$}}
|
// CHECK-NEXT: LiveIn:{{ *$}}
|
||||||
// CHECK-NEXT: LiveOut: arg1@0
|
// CHECK-NEXT: LiveOut: arg1@0
|
||||||
%const0 = arith.constant 0 : i32
|
%const0 = arith.constant 0 : i32
|
||||||
br ^loopHeader(%const0, %arg0 : i32, i32)
|
cf.br ^loopHeader(%const0, %arg0 : i32, i32)
|
||||||
^loopHeader(%counter : i32, %i : i32):
|
^loopHeader(%counter : i32, %i : i32):
|
||||||
// CHECK: Block: 1
|
// CHECK: Block: 1
|
||||||
// CHECK-NEXT: LiveIn: arg1@0
|
// CHECK-NEXT: LiveIn: arg1@0
|
||||||
|
@ -82,10 +82,10 @@ func @func_loop(%arg0 : i32, %arg1 : i32) -> i32 {
|
||||||
// CHECK-NEXT: BeginLiveness
|
// CHECK-NEXT: BeginLiveness
|
||||||
// CHECK-NEXT: val_5
|
// CHECK-NEXT: val_5
|
||||||
// CHECK-NEXT: %2 = arith.cmpi
|
// CHECK-NEXT: %2 = arith.cmpi
|
||||||
// CHECK-NEXT: cond_br
|
// CHECK-NEXT: cf.cond_br
|
||||||
// CHECK-NEXT: EndLiveness
|
// CHECK-NEXT: EndLiveness
|
||||||
%lessThan = arith.cmpi slt, %counter, %arg1 : i32
|
%lessThan = arith.cmpi slt, %counter, %arg1 : i32
|
||||||
cond_br %lessThan, ^loopBody(%i : i32), ^exit(%i : i32)
|
cf.cond_br %lessThan, ^loopBody(%i : i32), ^exit(%i : i32)
|
||||||
^loopBody(%val : i32):
|
^loopBody(%val : i32):
|
||||||
// CHECK: Block: 2
|
// CHECK: Block: 2
|
||||||
// CHECK-NEXT: LiveIn: arg1@0 arg0@1
|
// CHECK-NEXT: LiveIn: arg1@0 arg0@1
|
||||||
|
@ -98,12 +98,12 @@ func @func_loop(%arg0 : i32, %arg1 : i32) -> i32 {
|
||||||
// CHECK-NEXT: val_8
|
// CHECK-NEXT: val_8
|
||||||
// CHECK-NEXT: %4 = arith.addi
|
// CHECK-NEXT: %4 = arith.addi
|
||||||
// CHECK-NEXT: %5 = arith.addi
|
// CHECK-NEXT: %5 = arith.addi
|
||||||
// CHECK-NEXT: br
|
// CHECK-NEXT: cf.br
|
||||||
// CHECK: EndLiveness
|
// CHECK: EndLiveness
|
||||||
%const1 = arith.constant 1 : i32
|
%const1 = arith.constant 1 : i32
|
||||||
%inc = arith.addi %val, %const1 : i32
|
%inc = arith.addi %val, %const1 : i32
|
||||||
%inc2 = arith.addi %counter, %const1 : i32
|
%inc2 = arith.addi %counter, %const1 : i32
|
||||||
br ^loopHeader(%inc, %inc2 : i32, i32)
|
cf.br ^loopHeader(%inc, %inc2 : i32, i32)
|
||||||
^exit(%sum : i32):
|
^exit(%sum : i32):
|
||||||
// CHECK: Block: 3
|
// CHECK: Block: 3
|
||||||
// CHECK-NEXT: LiveIn: arg1@0
|
// CHECK-NEXT: LiveIn: arg1@0
|
||||||
|
@ -147,14 +147,14 @@ func @func_ranges(%cond : i1, %arg1 : i32, %arg2 : i32, %arg3 : i32) -> i32 {
|
||||||
// CHECK-NEXT: val_9
|
// CHECK-NEXT: val_9
|
||||||
// CHECK-NEXT: %4 = arith.muli
|
// CHECK-NEXT: %4 = arith.muli
|
||||||
// CHECK-NEXT: %5 = arith.addi
|
// CHECK-NEXT: %5 = arith.addi
|
||||||
// CHECK-NEXT: cond_br
|
// CHECK-NEXT: cf.cond_br
|
||||||
// CHECK-NEXT: %c
|
// CHECK-NEXT: %c
|
||||||
// CHECK-NEXT: %6 = arith.muli
|
// CHECK-NEXT: %6 = arith.muli
|
||||||
// CHECK-NEXT: %7 = arith.muli
|
// CHECK-NEXT: %7 = arith.muli
|
||||||
// CHECK-NEXT: %8 = arith.addi
|
// CHECK-NEXT: %8 = arith.addi
|
||||||
// CHECK-NEXT: val_10
|
// CHECK-NEXT: val_10
|
||||||
// CHECK-NEXT: %5 = arith.addi
|
// CHECK-NEXT: %5 = arith.addi
|
||||||
// CHECK-NEXT: cond_br
|
// CHECK-NEXT: cf.cond_br
|
||||||
// CHECK-NEXT: %7
|
// CHECK-NEXT: %7
|
||||||
// CHECK: EndLiveness
|
// CHECK: EndLiveness
|
||||||
%0 = arith.addi %arg1, %arg2 : i32
|
%0 = arith.addi %arg1, %arg2 : i32
|
||||||
|
@ -164,7 +164,7 @@ func @func_ranges(%cond : i1, %arg1 : i32, %arg2 : i32, %arg3 : i32) -> i32 {
|
||||||
%3 = arith.muli %0, %1 : i32
|
%3 = arith.muli %0, %1 : i32
|
||||||
%4 = arith.muli %3, %2 : i32
|
%4 = arith.muli %3, %2 : i32
|
||||||
%5 = arith.addi %4, %const1 : i32
|
%5 = arith.addi %4, %const1 : i32
|
||||||
cond_br %cond, ^bb1, ^bb2
|
cf.cond_br %cond, ^bb1, ^bb2
|
||||||
|
|
||||||
^bb1:
|
^bb1:
|
||||||
// CHECK: Block: 1
|
// CHECK: Block: 1
|
||||||
|
@ -172,7 +172,7 @@ func @func_ranges(%cond : i1, %arg1 : i32, %arg2 : i32, %arg3 : i32) -> i32 {
|
||||||
// CHECK-NEXT: LiveOut: arg2@0
|
// CHECK-NEXT: LiveOut: arg2@0
|
||||||
%const4 = arith.constant 4 : i32
|
%const4 = arith.constant 4 : i32
|
||||||
%6 = arith.muli %4, %const4 : i32
|
%6 = arith.muli %4, %const4 : i32
|
||||||
br ^exit(%6 : i32)
|
cf.br ^exit(%6 : i32)
|
||||||
|
|
||||||
^bb2:
|
^bb2:
|
||||||
// CHECK: Block: 2
|
// CHECK: Block: 2
|
||||||
|
@ -180,7 +180,7 @@ func @func_ranges(%cond : i1, %arg1 : i32, %arg2 : i32, %arg3 : i32) -> i32 {
|
||||||
// CHECK-NEXT: LiveOut: arg2@0
|
// CHECK-NEXT: LiveOut: arg2@0
|
||||||
%7 = arith.muli %4, %5 : i32
|
%7 = arith.muli %4, %5 : i32
|
||||||
%8 = arith.addi %4, %arg2 : i32
|
%8 = arith.addi %4, %arg2 : i32
|
||||||
br ^exit(%8 : i32)
|
cf.br ^exit(%8 : i32)
|
||||||
|
|
||||||
^exit(%sum : i32):
|
^exit(%sum : i32):
|
||||||
// CHECK: Block: 3
|
// CHECK: Block: 3
|
||||||
|
@ -284,7 +284,7 @@ func @nested_region3(
|
||||||
// CHECK-NEXT: %0 = arith.addi
|
// CHECK-NEXT: %0 = arith.addi
|
||||||
// CHECK-NEXT: %1 = arith.addi
|
// CHECK-NEXT: %1 = arith.addi
|
||||||
// CHECK-NEXT: scf.for
|
// CHECK-NEXT: scf.for
|
||||||
// CHECK: // br ^bb1
|
// CHECK: // cf.br ^bb1
|
||||||
// CHECK-NEXT: %2 = arith.addi
|
// CHECK-NEXT: %2 = arith.addi
|
||||||
// CHECK-NEXT: scf.for
|
// CHECK-NEXT: scf.for
|
||||||
// CHECK: // %2 = arith.addi
|
// CHECK: // %2 = arith.addi
|
||||||
|
@ -301,7 +301,7 @@ func @nested_region3(
|
||||||
%2 = arith.addi %0, %arg5 : i32
|
%2 = arith.addi %0, %arg5 : i32
|
||||||
memref.store %2, %buffer[] : memref<i32>
|
memref.store %2, %buffer[] : memref<i32>
|
||||||
}
|
}
|
||||||
br ^exit
|
cf.br ^exit
|
||||||
|
|
||||||
^exit:
|
^exit:
|
||||||
// CHECK: Block: 2
|
// CHECK: Block: 2
|
||||||
|
|
|
@ -1531,10 +1531,10 @@ int registerOnlyStd() {
|
||||||
fprintf(stderr, "@registration\n");
|
fprintf(stderr, "@registration\n");
|
||||||
// CHECK-LABEL: @registration
|
// CHECK-LABEL: @registration
|
||||||
|
|
||||||
// CHECK: std.cond_br is_registered: 1
|
// CHECK: cf.cond_br is_registered: 1
|
||||||
fprintf(stderr, "std.cond_br is_registered: %d\n",
|
fprintf(stderr, "cf.cond_br is_registered: %d\n",
|
||||||
mlirContextIsRegisteredOperation(
|
mlirContextIsRegisteredOperation(
|
||||||
ctx, mlirStringRefCreateFromCString("std.cond_br")));
|
ctx, mlirStringRefCreateFromCString("cf.cond_br")));
|
||||||
|
|
||||||
// CHECK: std.not_existing_op is_registered: 0
|
// CHECK: std.not_existing_op is_registered: 0
|
||||||
fprintf(stderr, "std.not_existing_op is_registered: %d\n",
|
fprintf(stderr, "std.not_existing_op is_registered: %d\n",
|
||||||
|
|
|
@ -27,7 +27,7 @@ func @execute_no_async_args(%arg0: f32, %arg1: memref<1xf32>) {
|
||||||
// CHECK: %[[IS_ERROR:.*]] = call @mlirAsyncRuntimeIsTokenError(%[[TOKEN]])
|
// CHECK: %[[IS_ERROR:.*]] = call @mlirAsyncRuntimeIsTokenError(%[[TOKEN]])
|
||||||
// CHECK: %[[TRUE:.*]] = arith.constant true
|
// CHECK: %[[TRUE:.*]] = arith.constant true
|
||||||
// CHECK: %[[NOT_ERROR:.*]] = arith.xori %[[IS_ERROR]], %[[TRUE]] : i1
|
// CHECK: %[[NOT_ERROR:.*]] = arith.xori %[[IS_ERROR]], %[[TRUE]] : i1
|
||||||
// CHECK: assert %[[NOT_ERROR]]
|
// CHECK: cf.assert %[[NOT_ERROR]]
|
||||||
// CHECK-NEXT: return
|
// CHECK-NEXT: return
|
||||||
async.await %token : !async.token
|
async.await %token : !async.token
|
||||||
return
|
return
|
||||||
|
@ -90,7 +90,7 @@ func @nested_async_execute(%arg0: f32, %arg1: f32, %arg2: memref<1xf32>) {
|
||||||
// CHECK: %[[IS_ERROR:.*]] = call @mlirAsyncRuntimeIsTokenError(%[[TOKEN]])
|
// CHECK: %[[IS_ERROR:.*]] = call @mlirAsyncRuntimeIsTokenError(%[[TOKEN]])
|
||||||
// CHECK: %[[TRUE:.*]] = arith.constant true
|
// CHECK: %[[TRUE:.*]] = arith.constant true
|
||||||
// CHECK: %[[NOT_ERROR:.*]] = arith.xori %[[IS_ERROR]], %[[TRUE]] : i1
|
// CHECK: %[[NOT_ERROR:.*]] = arith.xori %[[IS_ERROR]], %[[TRUE]] : i1
|
||||||
// CHECK: assert %[[NOT_ERROR]]
|
// CHECK: cf.assert %[[NOT_ERROR]]
|
||||||
async.await %token0 : !async.token
|
async.await %token0 : !async.token
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,41 @@
|
||||||
|
// RUN: mlir-opt -split-input-file -convert-std-to-spirv -verify-diagnostics %s | FileCheck %s
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// cf.br, cf.cond_br
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
module attributes {
|
||||||
|
spv.target_env = #spv.target_env<#spv.vce<v1.0, [], []>, {}>
|
||||||
|
} {
|
||||||
|
|
||||||
|
// CHECK-LABEL: func @simple_loop
|
||||||
|
func @simple_loop(index, index, index) {
|
||||||
|
^bb0(%begin : index, %end : index, %step : index):
|
||||||
|
// CHECK-NEXT: spv.Branch ^bb1
|
||||||
|
cf.br ^bb1
|
||||||
|
|
||||||
|
// CHECK-NEXT: ^bb1: // pred: ^bb0
|
||||||
|
// CHECK-NEXT: spv.Branch ^bb2({{.*}} : i32)
|
||||||
|
^bb1: // pred: ^bb0
|
||||||
|
cf.br ^bb2(%begin : index)
|
||||||
|
|
||||||
|
// CHECK: ^bb2({{.*}}: i32): // 2 preds: ^bb1, ^bb3
|
||||||
|
// CHECK-NEXT: {{.*}} = spv.SLessThan {{.*}}, {{.*}} : i32
|
||||||
|
// CHECK-NEXT: spv.BranchConditional {{.*}}, ^bb3, ^bb4
|
||||||
|
^bb2(%0: index): // 2 preds: ^bb1, ^bb3
|
||||||
|
%1 = arith.cmpi slt, %0, %end : index
|
||||||
|
cf.cond_br %1, ^bb3, ^bb4
|
||||||
|
|
||||||
|
// CHECK: ^bb3: // pred: ^bb2
|
||||||
|
// CHECK-NEXT: {{.*}} = spv.IAdd {{.*}}, {{.*}} : i32
|
||||||
|
// CHECK-NEXT: spv.Branch ^bb2({{.*}} : i32)
|
||||||
|
^bb3: // pred: ^bb2
|
||||||
|
%2 = arith.addi %0, %step : index
|
||||||
|
cf.br ^bb2(%2 : index)
|
||||||
|
|
||||||
|
// CHECK: ^bb4: // pred: ^bb2
|
||||||
|
^bb4: // pred: ^bb2
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
|
@ -168,16 +168,16 @@ gpu.module @test_module {
|
||||||
%c128 = arith.constant 128 : index
|
%c128 = arith.constant 128 : index
|
||||||
%c32 = arith.constant 32 : index
|
%c32 = arith.constant 32 : index
|
||||||
%0 = gpu.subgroup_mma_load_matrix %arg2[%c0, %c0] {leadDimension = 128 : index} : memref<128x128xf16> -> !gpu.mma_matrix<16x16xf16, "COp">
|
%0 = gpu.subgroup_mma_load_matrix %arg2[%c0, %c0] {leadDimension = 128 : index} : memref<128x128xf16> -> !gpu.mma_matrix<16x16xf16, "COp">
|
||||||
br ^bb1(%c0, %0 : index, !gpu.mma_matrix<16x16xf16, "COp">)
|
cf.br ^bb1(%c0, %0 : index, !gpu.mma_matrix<16x16xf16, "COp">)
|
||||||
^bb1(%1: index, %2: !gpu.mma_matrix<16x16xf16, "COp">): // 2 preds: ^bb0, ^bb2
|
^bb1(%1: index, %2: !gpu.mma_matrix<16x16xf16, "COp">): // 2 preds: ^bb0, ^bb2
|
||||||
%3 = arith.cmpi slt, %1, %c128 : index
|
%3 = arith.cmpi slt, %1, %c128 : index
|
||||||
cond_br %3, ^bb2, ^bb3
|
cf.cond_br %3, ^bb2, ^bb3
|
||||||
^bb2: // pred: ^bb1
|
^bb2: // pred: ^bb1
|
||||||
%4 = gpu.subgroup_mma_load_matrix %arg0[%c0, %1] {leadDimension = 128 : index} : memref<128x128xf16> -> !gpu.mma_matrix<16x16xf16, "AOp">
|
%4 = gpu.subgroup_mma_load_matrix %arg0[%c0, %1] {leadDimension = 128 : index} : memref<128x128xf16> -> !gpu.mma_matrix<16x16xf16, "AOp">
|
||||||
%5 = gpu.subgroup_mma_load_matrix %arg1[%1, %c0] {leadDimension = 128 : index} : memref<128x128xf16> -> !gpu.mma_matrix<16x16xf16, "BOp">
|
%5 = gpu.subgroup_mma_load_matrix %arg1[%1, %c0] {leadDimension = 128 : index} : memref<128x128xf16> -> !gpu.mma_matrix<16x16xf16, "BOp">
|
||||||
%6 = gpu.subgroup_mma_compute %4, %5, %2 : !gpu.mma_matrix<16x16xf16, "AOp">, !gpu.mma_matrix<16x16xf16, "BOp"> -> !gpu.mma_matrix<16x16xf16, "COp">
|
%6 = gpu.subgroup_mma_compute %4, %5, %2 : !gpu.mma_matrix<16x16xf16, "AOp">, !gpu.mma_matrix<16x16xf16, "BOp"> -> !gpu.mma_matrix<16x16xf16, "COp">
|
||||||
%7 = arith.addi %1, %c32 : index
|
%7 = arith.addi %1, %c32 : index
|
||||||
br ^bb1(%7, %6 : index, !gpu.mma_matrix<16x16xf16, "COp">)
|
cf.br ^bb1(%7, %6 : index, !gpu.mma_matrix<16x16xf16, "COp">)
|
||||||
^bb3: // pred: ^bb1
|
^bb3: // pred: ^bb1
|
||||||
gpu.subgroup_mma_store_matrix %2, %arg2[%c0, %c0] {leadDimension = 128 : index} : !gpu.mma_matrix<16x16xf16, "COp">, memref<128x128xf16>
|
gpu.subgroup_mma_store_matrix %2, %arg2[%c0, %c0] {leadDimension = 128 : index} : !gpu.mma_matrix<16x16xf16, "COp">, memref<128x128xf16>
|
||||||
return
|
return
|
||||||
|
|
|
@ -22,17 +22,17 @@ func @branch_loop() {
|
||||||
// CHECK: omp.parallel
|
// CHECK: omp.parallel
|
||||||
omp.parallel {
|
omp.parallel {
|
||||||
// CHECK-NEXT: llvm.br ^[[BB1:.*]](%{{[0-9]+}}, %{{[0-9]+}} : i64, i64
|
// CHECK-NEXT: llvm.br ^[[BB1:.*]](%{{[0-9]+}}, %{{[0-9]+}} : i64, i64
|
||||||
br ^bb1(%start, %end : index, index)
|
cf.br ^bb1(%start, %end : index, index)
|
||||||
// CHECK-NEXT: ^[[BB1]](%[[ARG1:[0-9]+]]: i64, %[[ARG2:[0-9]+]]: i64):{{.*}}
|
// CHECK-NEXT: ^[[BB1]](%[[ARG1:[0-9]+]]: i64, %[[ARG2:[0-9]+]]: i64):{{.*}}
|
||||||
^bb1(%0: index, %1: index):
|
^bb1(%0: index, %1: index):
|
||||||
// CHECK-NEXT: %[[CMP:[0-9]+]] = llvm.icmp "slt" %[[ARG1]], %[[ARG2]] : i64
|
// CHECK-NEXT: %[[CMP:[0-9]+]] = llvm.icmp "slt" %[[ARG1]], %[[ARG2]] : i64
|
||||||
%2 = arith.cmpi slt, %0, %1 : index
|
%2 = arith.cmpi slt, %0, %1 : index
|
||||||
// CHECK-NEXT: llvm.cond_br %[[CMP]], ^[[BB2:.*]](%{{[0-9]+}}, %{{[0-9]+}} : i64, i64), ^[[BB3:.*]]
|
// CHECK-NEXT: llvm.cond_br %[[CMP]], ^[[BB2:.*]](%{{[0-9]+}}, %{{[0-9]+}} : i64, i64), ^[[BB3:.*]]
|
||||||
cond_br %2, ^bb2(%end, %end : index, index), ^bb3
|
cf.cond_br %2, ^bb2(%end, %end : index, index), ^bb3
|
||||||
// CHECK-NEXT: ^[[BB2]](%[[ARG3:[0-9]+]]: i64, %[[ARG4:[0-9]+]]: i64):
|
// CHECK-NEXT: ^[[BB2]](%[[ARG3:[0-9]+]]: i64, %[[ARG4:[0-9]+]]: i64):
|
||||||
^bb2(%3: index, %4: index):
|
^bb2(%3: index, %4: index):
|
||||||
// CHECK-NEXT: llvm.br ^[[BB1]](%[[ARG3]], %[[ARG4]] : i64, i64)
|
// CHECK-NEXT: llvm.br ^[[BB1]](%[[ARG3]], %[[ARG4]] : i64, i64)
|
||||||
br ^bb1(%3, %4 : index, index)
|
cf.br ^bb1(%3, %4 : index, index)
|
||||||
// CHECK-NEXT: ^[[BB3]]:
|
// CHECK-NEXT: ^[[BB3]]:
|
||||||
^bb3:
|
^bb3:
|
||||||
omp.flush
|
omp.flush
|
||||||
|
|
|
@ -1,14 +1,14 @@
|
||||||
// RUN: mlir-opt -allow-unregistered-dialect -convert-scf-to-std %s | FileCheck %s
|
// RUN: mlir-opt -allow-unregistered-dialect -convert-scf-to-cf %s | FileCheck %s
|
||||||
|
|
||||||
// CHECK-LABEL: func @simple_std_for_loop(%{{.*}}: index, %{{.*}}: index, %{{.*}}: index) {
|
// CHECK-LABEL: func @simple_std_for_loop(%{{.*}}: index, %{{.*}}: index, %{{.*}}: index) {
|
||||||
// CHECK-NEXT: br ^bb1(%{{.*}} : index)
|
// CHECK-NEXT: cf.br ^bb1(%{{.*}} : index)
|
||||||
// CHECK-NEXT: ^bb1(%{{.*}}: index): // 2 preds: ^bb0, ^bb2
|
// CHECK-NEXT: ^bb1(%{{.*}}: index): // 2 preds: ^bb0, ^bb2
|
||||||
// CHECK-NEXT: %{{.*}} = arith.cmpi slt, %{{.*}}, %{{.*}} : index
|
// CHECK-NEXT: %{{.*}} = arith.cmpi slt, %{{.*}}, %{{.*}} : index
|
||||||
// CHECK-NEXT: cond_br %{{.*}}, ^bb2, ^bb3
|
// CHECK-NEXT: cf.cond_br %{{.*}}, ^bb2, ^bb3
|
||||||
// CHECK-NEXT: ^bb2: // pred: ^bb1
|
// CHECK-NEXT: ^bb2: // pred: ^bb1
|
||||||
// CHECK-NEXT: %{{.*}} = arith.constant 1 : index
|
// CHECK-NEXT: %{{.*}} = arith.constant 1 : index
|
||||||
// CHECK-NEXT: %[[iv:.*]] = arith.addi %{{.*}}, %{{.*}} : index
|
// CHECK-NEXT: %[[iv:.*]] = arith.addi %{{.*}}, %{{.*}} : index
|
||||||
// CHECK-NEXT: br ^bb1(%[[iv]] : index)
|
// CHECK-NEXT: cf.br ^bb1(%[[iv]] : index)
|
||||||
// CHECK-NEXT: ^bb3: // pred: ^bb1
|
// CHECK-NEXT: ^bb3: // pred: ^bb1
|
||||||
// CHECK-NEXT: return
|
// CHECK-NEXT: return
|
||||||
func @simple_std_for_loop(%arg0 : index, %arg1 : index, %arg2 : index) {
|
func @simple_std_for_loop(%arg0 : index, %arg1 : index, %arg2 : index) {
|
||||||
|
@ -19,23 +19,23 @@ func @simple_std_for_loop(%arg0 : index, %arg1 : index, %arg2 : index) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// CHECK-LABEL: func @simple_std_2_for_loops(%{{.*}}: index, %{{.*}}: index, %{{.*}}: index) {
|
// CHECK-LABEL: func @simple_std_2_for_loops(%{{.*}}: index, %{{.*}}: index, %{{.*}}: index) {
|
||||||
// CHECK-NEXT: br ^bb1(%{{.*}} : index)
|
// CHECK-NEXT: cf.br ^bb1(%{{.*}} : index)
|
||||||
// CHECK-NEXT: ^bb1(%[[ub0:.*]]: index): // 2 preds: ^bb0, ^bb5
|
// CHECK-NEXT: ^bb1(%[[ub0:.*]]: index): // 2 preds: ^bb0, ^bb5
|
||||||
// CHECK-NEXT: %[[cond0:.*]] = arith.cmpi slt, %[[ub0]], %{{.*}} : index
|
// CHECK-NEXT: %[[cond0:.*]] = arith.cmpi slt, %[[ub0]], %{{.*}} : index
|
||||||
// CHECK-NEXT: cond_br %[[cond0]], ^bb2, ^bb6
|
// CHECK-NEXT: cf.cond_br %[[cond0]], ^bb2, ^bb6
|
||||||
// CHECK-NEXT: ^bb2: // pred: ^bb1
|
// CHECK-NEXT: ^bb2: // pred: ^bb1
|
||||||
// CHECK-NEXT: %{{.*}} = arith.constant 1 : index
|
// CHECK-NEXT: %{{.*}} = arith.constant 1 : index
|
||||||
// CHECK-NEXT: br ^bb3(%{{.*}} : index)
|
// CHECK-NEXT: cf.br ^bb3(%{{.*}} : index)
|
||||||
// CHECK-NEXT: ^bb3(%[[ub1:.*]]: index): // 2 preds: ^bb2, ^bb4
|
// CHECK-NEXT: ^bb3(%[[ub1:.*]]: index): // 2 preds: ^bb2, ^bb4
|
||||||
// CHECK-NEXT: %[[cond1:.*]] = arith.cmpi slt, %{{.*}}, %{{.*}} : index
|
// CHECK-NEXT: %[[cond1:.*]] = arith.cmpi slt, %{{.*}}, %{{.*}} : index
|
||||||
// CHECK-NEXT: cond_br %[[cond1]], ^bb4, ^bb5
|
// CHECK-NEXT: cf.cond_br %[[cond1]], ^bb4, ^bb5
|
||||||
// CHECK-NEXT: ^bb4: // pred: ^bb3
|
// CHECK-NEXT: ^bb4: // pred: ^bb3
|
||||||
// CHECK-NEXT: %{{.*}} = arith.constant 1 : index
|
// CHECK-NEXT: %{{.*}} = arith.constant 1 : index
|
||||||
// CHECK-NEXT: %[[iv1:.*]] = arith.addi %{{.*}}, %{{.*}} : index
|
// CHECK-NEXT: %[[iv1:.*]] = arith.addi %{{.*}}, %{{.*}} : index
|
||||||
// CHECK-NEXT: br ^bb3(%[[iv1]] : index)
|
// CHECK-NEXT: cf.br ^bb3(%[[iv1]] : index)
|
||||||
// CHECK-NEXT: ^bb5: // pred: ^bb3
|
// CHECK-NEXT: ^bb5: // pred: ^bb3
|
||||||
// CHECK-NEXT: %[[iv0:.*]] = arith.addi %{{.*}}, %{{.*}} : index
|
// CHECK-NEXT: %[[iv0:.*]] = arith.addi %{{.*}}, %{{.*}} : index
|
||||||
// CHECK-NEXT: br ^bb1(%[[iv0]] : index)
|
// CHECK-NEXT: cf.br ^bb1(%[[iv0]] : index)
|
||||||
// CHECK-NEXT: ^bb6: // pred: ^bb1
|
// CHECK-NEXT: ^bb6: // pred: ^bb1
|
||||||
// CHECK-NEXT: return
|
// CHECK-NEXT: return
|
||||||
func @simple_std_2_for_loops(%arg0 : index, %arg1 : index, %arg2 : index) {
|
func @simple_std_2_for_loops(%arg0 : index, %arg1 : index, %arg2 : index) {
|
||||||
|
@ -49,10 +49,10 @@ func @simple_std_2_for_loops(%arg0 : index, %arg1 : index, %arg2 : index) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// CHECK-LABEL: func @simple_std_if(%{{.*}}: i1) {
|
// CHECK-LABEL: func @simple_std_if(%{{.*}}: i1) {
|
||||||
// CHECK-NEXT: cond_br %{{.*}}, ^bb1, ^bb2
|
// CHECK-NEXT: cf.cond_br %{{.*}}, ^bb1, ^bb2
|
||||||
// CHECK-NEXT: ^bb1: // pred: ^bb0
|
// CHECK-NEXT: ^bb1: // pred: ^bb0
|
||||||
// CHECK-NEXT: %{{.*}} = arith.constant 1 : index
|
// CHECK-NEXT: %{{.*}} = arith.constant 1 : index
|
||||||
// CHECK-NEXT: br ^bb2
|
// CHECK-NEXT: cf.br ^bb2
|
||||||
// CHECK-NEXT: ^bb2: // 2 preds: ^bb0, ^bb1
|
// CHECK-NEXT: ^bb2: // 2 preds: ^bb0, ^bb1
|
||||||
// CHECK-NEXT: return
|
// CHECK-NEXT: return
|
||||||
func @simple_std_if(%arg0: i1) {
|
func @simple_std_if(%arg0: i1) {
|
||||||
|
@ -63,13 +63,13 @@ func @simple_std_if(%arg0: i1) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// CHECK-LABEL: func @simple_std_if_else(%{{.*}}: i1) {
|
// CHECK-LABEL: func @simple_std_if_else(%{{.*}}: i1) {
|
||||||
// CHECK-NEXT: cond_br %{{.*}}, ^bb1, ^bb2
|
// CHECK-NEXT: cf.cond_br %{{.*}}, ^bb1, ^bb2
|
||||||
// CHECK-NEXT: ^bb1: // pred: ^bb0
|
// CHECK-NEXT: ^bb1: // pred: ^bb0
|
||||||
// CHECK-NEXT: %{{.*}} = arith.constant 1 : index
|
// CHECK-NEXT: %{{.*}} = arith.constant 1 : index
|
||||||
// CHECK-NEXT: br ^bb3
|
// CHECK-NEXT: cf.br ^bb3
|
||||||
// CHECK-NEXT: ^bb2: // pred: ^bb0
|
// CHECK-NEXT: ^bb2: // pred: ^bb0
|
||||||
// CHECK-NEXT: %{{.*}} = arith.constant 1 : index
|
// CHECK-NEXT: %{{.*}} = arith.constant 1 : index
|
||||||
// CHECK-NEXT: br ^bb3
|
// CHECK-NEXT: cf.br ^bb3
|
||||||
// CHECK-NEXT: ^bb3: // 2 preds: ^bb1, ^bb2
|
// CHECK-NEXT: ^bb3: // 2 preds: ^bb1, ^bb2
|
||||||
// CHECK-NEXT: return
|
// CHECK-NEXT: return
|
||||||
func @simple_std_if_else(%arg0: i1) {
|
func @simple_std_if_else(%arg0: i1) {
|
||||||
|
@ -82,18 +82,18 @@ func @simple_std_if_else(%arg0: i1) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// CHECK-LABEL: func @simple_std_2_ifs(%{{.*}}: i1) {
|
// CHECK-LABEL: func @simple_std_2_ifs(%{{.*}}: i1) {
|
||||||
// CHECK-NEXT: cond_br %{{.*}}, ^bb1, ^bb5
|
// CHECK-NEXT: cf.cond_br %{{.*}}, ^bb1, ^bb5
|
||||||
// CHECK-NEXT: ^bb1: // pred: ^bb0
|
// CHECK-NEXT: ^bb1: // pred: ^bb0
|
||||||
// CHECK-NEXT: %{{.*}} = arith.constant 1 : index
|
// CHECK-NEXT: %{{.*}} = arith.constant 1 : index
|
||||||
// CHECK-NEXT: cond_br %{{.*}}, ^bb2, ^bb3
|
// CHECK-NEXT: cf.cond_br %{{.*}}, ^bb2, ^bb3
|
||||||
// CHECK-NEXT: ^bb2: // pred: ^bb1
|
// CHECK-NEXT: ^bb2: // pred: ^bb1
|
||||||
// CHECK-NEXT: %{{.*}} = arith.constant 1 : index
|
// CHECK-NEXT: %{{.*}} = arith.constant 1 : index
|
||||||
// CHECK-NEXT: br ^bb4
|
// CHECK-NEXT: cf.br ^bb4
|
||||||
// CHECK-NEXT: ^bb3: // pred: ^bb1
|
// CHECK-NEXT: ^bb3: // pred: ^bb1
|
||||||
// CHECK-NEXT: %{{.*}} = arith.constant 1 : index
|
// CHECK-NEXT: %{{.*}} = arith.constant 1 : index
|
||||||
// CHECK-NEXT: br ^bb4
|
// CHECK-NEXT: cf.br ^bb4
|
||||||
// CHECK-NEXT: ^bb4: // 2 preds: ^bb2, ^bb3
|
// CHECK-NEXT: ^bb4: // 2 preds: ^bb2, ^bb3
|
||||||
// CHECK-NEXT: br ^bb5
|
// CHECK-NEXT: cf.br ^bb5
|
||||||
// CHECK-NEXT: ^bb5: // 2 preds: ^bb0, ^bb4
|
// CHECK-NEXT: ^bb5: // 2 preds: ^bb0, ^bb4
|
||||||
// CHECK-NEXT: return
|
// CHECK-NEXT: return
|
||||||
func @simple_std_2_ifs(%arg0: i1) {
|
func @simple_std_2_ifs(%arg0: i1) {
|
||||||
|
@ -109,27 +109,27 @@ func @simple_std_2_ifs(%arg0: i1) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// CHECK-LABEL: func @simple_std_for_loop_with_2_ifs(%{{.*}}: index, %{{.*}}: index, %{{.*}}: index, %{{.*}}: i1) {
|
// CHECK-LABEL: func @simple_std_for_loop_with_2_ifs(%{{.*}}: index, %{{.*}}: index, %{{.*}}: index, %{{.*}}: i1) {
|
||||||
// CHECK-NEXT: br ^bb1(%{{.*}} : index)
|
// CHECK-NEXT: cf.br ^bb1(%{{.*}} : index)
|
||||||
// CHECK-NEXT: ^bb1(%{{.*}}: index): // 2 preds: ^bb0, ^bb7
|
// CHECK-NEXT: ^bb1(%{{.*}}: index): // 2 preds: ^bb0, ^bb7
|
||||||
// CHECK-NEXT: %{{.*}} = arith.cmpi slt, %{{.*}}, %{{.*}} : index
|
// CHECK-NEXT: %{{.*}} = arith.cmpi slt, %{{.*}}, %{{.*}} : index
|
||||||
// CHECK-NEXT: cond_br %{{.*}}, ^bb2, ^bb8
|
// CHECK-NEXT: cf.cond_br %{{.*}}, ^bb2, ^bb8
|
||||||
// CHECK-NEXT: ^bb2: // pred: ^bb1
|
// CHECK-NEXT: ^bb2: // pred: ^bb1
|
||||||
// CHECK-NEXT: %{{.*}} = arith.constant 1 : index
|
// CHECK-NEXT: %{{.*}} = arith.constant 1 : index
|
||||||
// CHECK-NEXT: cond_br %{{.*}}, ^bb3, ^bb7
|
// CHECK-NEXT: cf.cond_br %{{.*}}, ^bb3, ^bb7
|
||||||
// CHECK-NEXT: ^bb3: // pred: ^bb2
|
// CHECK-NEXT: ^bb3: // pred: ^bb2
|
||||||
// CHECK-NEXT: %{{.*}} = arith.constant 1 : index
|
// CHECK-NEXT: %{{.*}} = arith.constant 1 : index
|
||||||
// CHECK-NEXT: cond_br %{{.*}}, ^bb4, ^bb5
|
// CHECK-NEXT: cf.cond_br %{{.*}}, ^bb4, ^bb5
|
||||||
// CHECK-NEXT: ^bb4: // pred: ^bb3
|
// CHECK-NEXT: ^bb4: // pred: ^bb3
|
||||||
// CHECK-NEXT: %{{.*}} = arith.constant 1 : index
|
// CHECK-NEXT: %{{.*}} = arith.constant 1 : index
|
||||||
// CHECK-NEXT: br ^bb6
|
// CHECK-NEXT: cf.br ^bb6
|
||||||
// CHECK-NEXT: ^bb5: // pred: ^bb3
|
// CHECK-NEXT: ^bb5: // pred: ^bb3
|
||||||
// CHECK-NEXT: %{{.*}} = arith.constant 1 : index
|
// CHECK-NEXT: %{{.*}} = arith.constant 1 : index
|
||||||
// CHECK-NEXT: br ^bb6
|
// CHECK-NEXT: cf.br ^bb6
|
||||||
// CHECK-NEXT: ^bb6: // 2 preds: ^bb4, ^bb5
|
// CHECK-NEXT: ^bb6: // 2 preds: ^bb4, ^bb5
|
||||||
// CHECK-NEXT: br ^bb7
|
// CHECK-NEXT: cf.br ^bb7
|
||||||
// CHECK-NEXT: ^bb7: // 2 preds: ^bb2, ^bb6
|
// CHECK-NEXT: ^bb7: // 2 preds: ^bb2, ^bb6
|
||||||
// CHECK-NEXT: %[[iv0:.*]] = arith.addi %{{.*}}, %{{.*}} : index
|
// CHECK-NEXT: %[[iv0:.*]] = arith.addi %{{.*}}, %{{.*}} : index
|
||||||
// CHECK-NEXT: br ^bb1(%[[iv0]] : index)
|
// CHECK-NEXT: cf.br ^bb1(%[[iv0]] : index)
|
||||||
// CHECK-NEXT: ^bb8: // pred: ^bb1
|
// CHECK-NEXT: ^bb8: // pred: ^bb1
|
||||||
// CHECK-NEXT: return
|
// CHECK-NEXT: return
|
||||||
// CHECK-NEXT: }
|
// CHECK-NEXT: }
|
||||||
|
@ -150,12 +150,12 @@ func @simple_std_for_loop_with_2_ifs(%arg0 : index, %arg1 : index, %arg2 : index
|
||||||
|
|
||||||
// CHECK-LABEL: func @simple_if_yield
|
// CHECK-LABEL: func @simple_if_yield
|
||||||
func @simple_if_yield(%arg0: i1) -> (i1, i1) {
|
func @simple_if_yield(%arg0: i1) -> (i1, i1) {
|
||||||
// CHECK: cond_br %{{.*}}, ^[[then:.*]], ^[[else:.*]]
|
// CHECK: cf.cond_br %{{.*}}, ^[[then:.*]], ^[[else:.*]]
|
||||||
%0:2 = scf.if %arg0 -> (i1, i1) {
|
%0:2 = scf.if %arg0 -> (i1, i1) {
|
||||||
// CHECK: ^[[then]]:
|
// CHECK: ^[[then]]:
|
||||||
// CHECK: %[[v0:.*]] = arith.constant false
|
// CHECK: %[[v0:.*]] = arith.constant false
|
||||||
// CHECK: %[[v1:.*]] = arith.constant true
|
// CHECK: %[[v1:.*]] = arith.constant true
|
||||||
// CHECK: br ^[[dom:.*]](%[[v0]], %[[v1]] : i1, i1)
|
// CHECK: cf.br ^[[dom:.*]](%[[v0]], %[[v1]] : i1, i1)
|
||||||
%c0 = arith.constant false
|
%c0 = arith.constant false
|
||||||
%c1 = arith.constant true
|
%c1 = arith.constant true
|
||||||
scf.yield %c0, %c1 : i1, i1
|
scf.yield %c0, %c1 : i1, i1
|
||||||
|
@ -163,13 +163,13 @@ func @simple_if_yield(%arg0: i1) -> (i1, i1) {
|
||||||
// CHECK: ^[[else]]:
|
// CHECK: ^[[else]]:
|
||||||
// CHECK: %[[v2:.*]] = arith.constant false
|
// CHECK: %[[v2:.*]] = arith.constant false
|
||||||
// CHECK: %[[v3:.*]] = arith.constant true
|
// CHECK: %[[v3:.*]] = arith.constant true
|
||||||
// CHECK: br ^[[dom]](%[[v3]], %[[v2]] : i1, i1)
|
// CHECK: cf.br ^[[dom]](%[[v3]], %[[v2]] : i1, i1)
|
||||||
%c0 = arith.constant false
|
%c0 = arith.constant false
|
||||||
%c1 = arith.constant true
|
%c1 = arith.constant true
|
||||||
scf.yield %c1, %c0 : i1, i1
|
scf.yield %c1, %c0 : i1, i1
|
||||||
}
|
}
|
||||||
// CHECK: ^[[dom]](%[[arg1:.*]]: i1, %[[arg2:.*]]: i1):
|
// CHECK: ^[[dom]](%[[arg1:.*]]: i1, %[[arg2:.*]]: i1):
|
||||||
// CHECK: br ^[[cont:.*]]
|
// CHECK: cf.br ^[[cont:.*]]
|
||||||
// CHECK: ^[[cont]]:
|
// CHECK: ^[[cont]]:
|
||||||
// CHECK: return %[[arg1]], %[[arg2]]
|
// CHECK: return %[[arg1]], %[[arg2]]
|
||||||
return %0#0, %0#1 : i1, i1
|
return %0#0, %0#1 : i1, i1
|
||||||
|
@ -177,49 +177,49 @@ func @simple_if_yield(%arg0: i1) -> (i1, i1) {
|
||||||
|
|
||||||
// CHECK-LABEL: func @nested_if_yield
|
// CHECK-LABEL: func @nested_if_yield
|
||||||
func @nested_if_yield(%arg0: i1) -> (index) {
|
func @nested_if_yield(%arg0: i1) -> (index) {
|
||||||
// CHECK: cond_br %{{.*}}, ^[[first_then:.*]], ^[[first_else:.*]]
|
// CHECK: cf.cond_br %{{.*}}, ^[[first_then:.*]], ^[[first_else:.*]]
|
||||||
%0 = scf.if %arg0 -> i1 {
|
%0 = scf.if %arg0 -> i1 {
|
||||||
// CHECK: ^[[first_then]]:
|
// CHECK: ^[[first_then]]:
|
||||||
%1 = arith.constant true
|
%1 = arith.constant true
|
||||||
// CHECK: br ^[[first_dom:.*]]({{.*}})
|
// CHECK: cf.br ^[[first_dom:.*]]({{.*}})
|
||||||
scf.yield %1 : i1
|
scf.yield %1 : i1
|
||||||
} else {
|
} else {
|
||||||
// CHECK: ^[[first_else]]:
|
// CHECK: ^[[first_else]]:
|
||||||
%2 = arith.constant false
|
%2 = arith.constant false
|
||||||
// CHECK: br ^[[first_dom]]({{.*}})
|
// CHECK: cf.br ^[[first_dom]]({{.*}})
|
||||||
scf.yield %2 : i1
|
scf.yield %2 : i1
|
||||||
}
|
}
|
||||||
// CHECK: ^[[first_dom]](%[[arg1:.*]]: i1):
|
// CHECK: ^[[first_dom]](%[[arg1:.*]]: i1):
|
||||||
// CHECK: br ^[[first_cont:.*]]
|
// CHECK: cf.br ^[[first_cont:.*]]
|
||||||
// CHECK: ^[[first_cont]]:
|
// CHECK: ^[[first_cont]]:
|
||||||
// CHECK: cond_br %[[arg1]], ^[[second_outer_then:.*]], ^[[second_outer_else:.*]]
|
// CHECK: cf.cond_br %[[arg1]], ^[[second_outer_then:.*]], ^[[second_outer_else:.*]]
|
||||||
%1 = scf.if %0 -> index {
|
%1 = scf.if %0 -> index {
|
||||||
// CHECK: ^[[second_outer_then]]:
|
// CHECK: ^[[second_outer_then]]:
|
||||||
// CHECK: cond_br %arg0, ^[[second_inner_then:.*]], ^[[second_inner_else:.*]]
|
// CHECK: cf.cond_br %arg0, ^[[second_inner_then:.*]], ^[[second_inner_else:.*]]
|
||||||
%3 = scf.if %arg0 -> index {
|
%3 = scf.if %arg0 -> index {
|
||||||
// CHECK: ^[[second_inner_then]]:
|
// CHECK: ^[[second_inner_then]]:
|
||||||
%4 = arith.constant 40 : index
|
%4 = arith.constant 40 : index
|
||||||
// CHECK: br ^[[second_inner_dom:.*]]({{.*}})
|
// CHECK: cf.br ^[[second_inner_dom:.*]]({{.*}})
|
||||||
scf.yield %4 : index
|
scf.yield %4 : index
|
||||||
} else {
|
} else {
|
||||||
// CHECK: ^[[second_inner_else]]:
|
// CHECK: ^[[second_inner_else]]:
|
||||||
%5 = arith.constant 41 : index
|
%5 = arith.constant 41 : index
|
||||||
// CHECK: br ^[[second_inner_dom]]({{.*}})
|
// CHECK: cf.br ^[[second_inner_dom]]({{.*}})
|
||||||
scf.yield %5 : index
|
scf.yield %5 : index
|
||||||
}
|
}
|
||||||
// CHECK: ^[[second_inner_dom]](%[[arg2:.*]]: index):
|
// CHECK: ^[[second_inner_dom]](%[[arg2:.*]]: index):
|
||||||
// CHECK: br ^[[second_inner_cont:.*]]
|
// CHECK: cf.br ^[[second_inner_cont:.*]]
|
||||||
// CHECK: ^[[second_inner_cont]]:
|
// CHECK: ^[[second_inner_cont]]:
|
||||||
// CHECK: br ^[[second_outer_dom:.*]]({{.*}})
|
// CHECK: cf.br ^[[second_outer_dom:.*]]({{.*}})
|
||||||
scf.yield %3 : index
|
scf.yield %3 : index
|
||||||
} else {
|
} else {
|
||||||
// CHECK: ^[[second_outer_else]]:
|
// CHECK: ^[[second_outer_else]]:
|
||||||
%6 = arith.constant 42 : index
|
%6 = arith.constant 42 : index
|
||||||
// CHECK: br ^[[second_outer_dom]]({{.*}}
|
// CHECK: cf.br ^[[second_outer_dom]]({{.*}}
|
||||||
scf.yield %6 : index
|
scf.yield %6 : index
|
||||||
}
|
}
|
||||||
// CHECK: ^[[second_outer_dom]](%[[arg3:.*]]: index):
|
// CHECK: ^[[second_outer_dom]](%[[arg3:.*]]: index):
|
||||||
// CHECK: br ^[[second_outer_cont:.*]]
|
// CHECK: cf.br ^[[second_outer_cont:.*]]
|
||||||
// CHECK: ^[[second_outer_cont]]:
|
// CHECK: ^[[second_outer_cont]]:
|
||||||
// CHECK: return %[[arg3]] : index
|
// CHECK: return %[[arg3]] : index
|
||||||
return %1 : index
|
return %1 : index
|
||||||
|
@ -228,22 +228,22 @@ func @nested_if_yield(%arg0: i1) -> (index) {
|
||||||
// CHECK-LABEL: func @parallel_loop(
|
// CHECK-LABEL: func @parallel_loop(
|
||||||
// CHECK-SAME: [[VAL_0:%.*]]: index, [[VAL_1:%.*]]: index, [[VAL_2:%.*]]: index, [[VAL_3:%.*]]: index, [[VAL_4:%.*]]: index) {
|
// CHECK-SAME: [[VAL_0:%.*]]: index, [[VAL_1:%.*]]: index, [[VAL_2:%.*]]: index, [[VAL_3:%.*]]: index, [[VAL_4:%.*]]: index) {
|
||||||
// CHECK: [[VAL_5:%.*]] = arith.constant 1 : index
|
// CHECK: [[VAL_5:%.*]] = arith.constant 1 : index
|
||||||
// CHECK: br ^bb1([[VAL_0]] : index)
|
// CHECK: cf.br ^bb1([[VAL_0]] : index)
|
||||||
// CHECK: ^bb1([[VAL_6:%.*]]: index):
|
// CHECK: ^bb1([[VAL_6:%.*]]: index):
|
||||||
// CHECK: [[VAL_7:%.*]] = arith.cmpi slt, [[VAL_6]], [[VAL_2]] : index
|
// CHECK: [[VAL_7:%.*]] = arith.cmpi slt, [[VAL_6]], [[VAL_2]] : index
|
||||||
// CHECK: cond_br [[VAL_7]], ^bb2, ^bb6
|
// CHECK: cf.cond_br [[VAL_7]], ^bb2, ^bb6
|
||||||
// CHECK: ^bb2:
|
// CHECK: ^bb2:
|
||||||
// CHECK: br ^bb3([[VAL_1]] : index)
|
// CHECK: cf.br ^bb3([[VAL_1]] : index)
|
||||||
// CHECK: ^bb3([[VAL_8:%.*]]: index):
|
// CHECK: ^bb3([[VAL_8:%.*]]: index):
|
||||||
// CHECK: [[VAL_9:%.*]] = arith.cmpi slt, [[VAL_8]], [[VAL_3]] : index
|
// CHECK: [[VAL_9:%.*]] = arith.cmpi slt, [[VAL_8]], [[VAL_3]] : index
|
||||||
// CHECK: cond_br [[VAL_9]], ^bb4, ^bb5
|
// CHECK: cf.cond_br [[VAL_9]], ^bb4, ^bb5
|
||||||
// CHECK: ^bb4:
|
// CHECK: ^bb4:
|
||||||
// CHECK: [[VAL_10:%.*]] = arith.constant 1 : index
|
// CHECK: [[VAL_10:%.*]] = arith.constant 1 : index
|
||||||
// CHECK: [[VAL_11:%.*]] = arith.addi [[VAL_8]], [[VAL_5]] : index
|
// CHECK: [[VAL_11:%.*]] = arith.addi [[VAL_8]], [[VAL_5]] : index
|
||||||
// CHECK: br ^bb3([[VAL_11]] : index)
|
// CHECK: cf.br ^bb3([[VAL_11]] : index)
|
||||||
// CHECK: ^bb5:
|
// CHECK: ^bb5:
|
||||||
// CHECK: [[VAL_12:%.*]] = arith.addi [[VAL_6]], [[VAL_4]] : index
|
// CHECK: [[VAL_12:%.*]] = arith.addi [[VAL_6]], [[VAL_4]] : index
|
||||||
// CHECK: br ^bb1([[VAL_12]] : index)
|
// CHECK: cf.br ^bb1([[VAL_12]] : index)
|
||||||
// CHECK: ^bb6:
|
// CHECK: ^bb6:
|
||||||
// CHECK: return
|
// CHECK: return
|
||||||
// CHECK: }
|
// CHECK: }
|
||||||
|
@ -262,16 +262,16 @@ func @parallel_loop(%arg0 : index, %arg1 : index, %arg2 : index,
|
||||||
// CHECK-SAME: (%[[LB:.*]]: index, %[[UB:.*]]: index, %[[STEP:.*]]: index)
|
// CHECK-SAME: (%[[LB:.*]]: index, %[[UB:.*]]: index, %[[STEP:.*]]: index)
|
||||||
// CHECK: %[[INIT0:.*]] = arith.constant 0
|
// CHECK: %[[INIT0:.*]] = arith.constant 0
|
||||||
// CHECK: %[[INIT1:.*]] = arith.constant 1
|
// CHECK: %[[INIT1:.*]] = arith.constant 1
|
||||||
// CHECK: br ^[[COND:.*]](%[[LB]], %[[INIT0]], %[[INIT1]] : index, f32, f32)
|
// CHECK: cf.br ^[[COND:.*]](%[[LB]], %[[INIT0]], %[[INIT1]] : index, f32, f32)
|
||||||
//
|
//
|
||||||
// CHECK: ^[[COND]](%[[ITER:.*]]: index, %[[ITER_ARG0:.*]]: f32, %[[ITER_ARG1:.*]]: f32):
|
// CHECK: ^[[COND]](%[[ITER:.*]]: index, %[[ITER_ARG0:.*]]: f32, %[[ITER_ARG1:.*]]: f32):
|
||||||
// CHECK: %[[CMP:.*]] = arith.cmpi slt, %[[ITER]], %[[UB]] : index
|
// CHECK: %[[CMP:.*]] = arith.cmpi slt, %[[ITER]], %[[UB]] : index
|
||||||
// CHECK: cond_br %[[CMP]], ^[[BODY:.*]], ^[[CONTINUE:.*]]
|
// CHECK: cf.cond_br %[[CMP]], ^[[BODY:.*]], ^[[CONTINUE:.*]]
|
||||||
//
|
//
|
||||||
// CHECK: ^[[BODY]]:
|
// CHECK: ^[[BODY]]:
|
||||||
// CHECK: %[[SUM:.*]] = arith.addf %[[ITER_ARG0]], %[[ITER_ARG1]] : f32
|
// CHECK: %[[SUM:.*]] = arith.addf %[[ITER_ARG0]], %[[ITER_ARG1]] : f32
|
||||||
// CHECK: %[[STEPPED:.*]] = arith.addi %[[ITER]], %[[STEP]] : index
|
// CHECK: %[[STEPPED:.*]] = arith.addi %[[ITER]], %[[STEP]] : index
|
||||||
// CHECK: br ^[[COND]](%[[STEPPED]], %[[SUM]], %[[SUM]] : index, f32, f32)
|
// CHECK: cf.br ^[[COND]](%[[STEPPED]], %[[SUM]], %[[SUM]] : index, f32, f32)
|
||||||
//
|
//
|
||||||
// CHECK: ^[[CONTINUE]]:
|
// CHECK: ^[[CONTINUE]]:
|
||||||
// CHECK: return %[[ITER_ARG0]], %[[ITER_ARG1]] : f32, f32
|
// CHECK: return %[[ITER_ARG0]], %[[ITER_ARG1]] : f32, f32
|
||||||
|
@ -288,18 +288,18 @@ func @for_yield(%arg0 : index, %arg1 : index, %arg2 : index) -> (f32, f32) {
|
||||||
// CHECK-LABEL: @nested_for_yield
|
// CHECK-LABEL: @nested_for_yield
|
||||||
// CHECK-SAME: (%[[LB:.*]]: index, %[[UB:.*]]: index, %[[STEP:.*]]: index)
|
// CHECK-SAME: (%[[LB:.*]]: index, %[[UB:.*]]: index, %[[STEP:.*]]: index)
|
||||||
// CHECK: %[[INIT:.*]] = arith.constant
|
// CHECK: %[[INIT:.*]] = arith.constant
|
||||||
// CHECK: br ^[[COND_OUT:.*]](%[[LB]], %[[INIT]] : index, f32)
|
// CHECK: cf.br ^[[COND_OUT:.*]](%[[LB]], %[[INIT]] : index, f32)
|
||||||
// CHECK: ^[[COND_OUT]](%[[ITER_OUT:.*]]: index, %[[ARG_OUT:.*]]: f32):
|
// CHECK: ^[[COND_OUT]](%[[ITER_OUT:.*]]: index, %[[ARG_OUT:.*]]: f32):
|
||||||
// CHECK: cond_br %{{.*}}, ^[[BODY_OUT:.*]], ^[[CONT_OUT:.*]]
|
// CHECK: cf.cond_br %{{.*}}, ^[[BODY_OUT:.*]], ^[[CONT_OUT:.*]]
|
||||||
// CHECK: ^[[BODY_OUT]]:
|
// CHECK: ^[[BODY_OUT]]:
|
||||||
// CHECK: br ^[[COND_IN:.*]](%[[LB]], %[[ARG_OUT]] : index, f32)
|
// CHECK: cf.br ^[[COND_IN:.*]](%[[LB]], %[[ARG_OUT]] : index, f32)
|
||||||
// CHECK: ^[[COND_IN]](%[[ITER_IN:.*]]: index, %[[ARG_IN:.*]]: f32):
|
// CHECK: ^[[COND_IN]](%[[ITER_IN:.*]]: index, %[[ARG_IN:.*]]: f32):
|
||||||
// CHECK: cond_br %{{.*}}, ^[[BODY_IN:.*]], ^[[CONT_IN:.*]]
|
// CHECK: cf.cond_br %{{.*}}, ^[[BODY_IN:.*]], ^[[CONT_IN:.*]]
|
||||||
// CHECK: ^[[BODY_IN]]
|
// CHECK: ^[[BODY_IN]]
|
||||||
// CHECK: %[[RES:.*]] = arith.addf
|
// CHECK: %[[RES:.*]] = arith.addf
|
||||||
// CHECK: br ^[[COND_IN]](%{{.*}}, %[[RES]] : index, f32)
|
// CHECK: cf.br ^[[COND_IN]](%{{.*}}, %[[RES]] : index, f32)
|
||||||
// CHECK: ^[[CONT_IN]]:
|
// CHECK: ^[[CONT_IN]]:
|
||||||
// CHECK: br ^[[COND_OUT]](%{{.*}}, %[[ARG_IN]] : index, f32)
|
// CHECK: cf.br ^[[COND_OUT]](%{{.*}}, %[[ARG_IN]] : index, f32)
|
||||||
// CHECK: ^[[CONT_OUT]]:
|
// CHECK: ^[[CONT_OUT]]:
|
||||||
// CHECK: return %[[ARG_OUT]] : f32
|
// CHECK: return %[[ARG_OUT]] : f32
|
||||||
func @nested_for_yield(%arg0 : index, %arg1 : index, %arg2 : index) -> f32 {
|
func @nested_for_yield(%arg0 : index, %arg1 : index, %arg2 : index) -> f32 {
|
||||||
|
@ -325,13 +325,13 @@ func @simple_parallel_reduce_loop(%arg0: index, %arg1: index,
|
||||||
// passed across as a block argument.
|
// passed across as a block argument.
|
||||||
|
|
||||||
// Branch to the condition block passing in the initial reduction value.
|
// Branch to the condition block passing in the initial reduction value.
|
||||||
// CHECK: br ^[[COND:.*]](%[[LB]], %[[INIT]]
|
// CHECK: cf.br ^[[COND:.*]](%[[LB]], %[[INIT]]
|
||||||
|
|
||||||
// Condition branch takes as arguments the current value of the iteration
|
// Condition branch takes as arguments the current value of the iteration
|
||||||
// variable and the current partially reduced value.
|
// variable and the current partially reduced value.
|
||||||
// CHECK: ^[[COND]](%[[ITER:.*]]: index, %[[ITER_ARG:.*]]: f32
|
// CHECK: ^[[COND]](%[[ITER:.*]]: index, %[[ITER_ARG:.*]]: f32
|
||||||
// CHECK: %[[COMP:.*]] = arith.cmpi slt, %[[ITER]], %[[UB]]
|
// CHECK: %[[COMP:.*]] = arith.cmpi slt, %[[ITER]], %[[UB]]
|
||||||
// CHECK: cond_br %[[COMP]], ^[[BODY:.*]], ^[[CONTINUE:.*]]
|
// CHECK: cf.cond_br %[[COMP]], ^[[BODY:.*]], ^[[CONTINUE:.*]]
|
||||||
|
|
||||||
// Bodies of scf.reduce operations are folded into the main loop body. The
|
// Bodies of scf.reduce operations are folded into the main loop body. The
|
||||||
// result of this partial reduction is passed as argument to the condition
|
// result of this partial reduction is passed as argument to the condition
|
||||||
|
@ -340,7 +340,7 @@ func @simple_parallel_reduce_loop(%arg0: index, %arg1: index,
|
||||||
// CHECK: %[[CST:.*]] = arith.constant 4.2
|
// CHECK: %[[CST:.*]] = arith.constant 4.2
|
||||||
// CHECK: %[[PROD:.*]] = arith.mulf %[[ITER_ARG]], %[[CST]]
|
// CHECK: %[[PROD:.*]] = arith.mulf %[[ITER_ARG]], %[[CST]]
|
||||||
// CHECK: %[[INCR:.*]] = arith.addi %[[ITER]], %[[STEP]]
|
// CHECK: %[[INCR:.*]] = arith.addi %[[ITER]], %[[STEP]]
|
||||||
// CHECK: br ^[[COND]](%[[INCR]], %[[PROD]]
|
// CHECK: cf.br ^[[COND]](%[[INCR]], %[[PROD]]
|
||||||
|
|
||||||
// The continuation block has access to the (last value of) reduction.
|
// The continuation block has access to the (last value of) reduction.
|
||||||
// CHECK: ^[[CONTINUE]]:
|
// CHECK: ^[[CONTINUE]]:
|
||||||
|
@ -363,19 +363,19 @@ func @parallel_reduce_loop(%arg0 : index, %arg1 : index, %arg2 : index,
|
||||||
// Multiple reduction blocks should be folded in the same body, and the
|
// Multiple reduction blocks should be folded in the same body, and the
|
||||||
// reduction value must be forwarded through block structures.
|
// reduction value must be forwarded through block structures.
|
||||||
// CHECK: %[[INIT2:.*]] = arith.constant 42
|
// CHECK: %[[INIT2:.*]] = arith.constant 42
|
||||||
// CHECK: br ^[[COND_OUT:.*]](%{{.*}}, %[[INIT1]], %[[INIT2]]
|
// CHECK: cf.br ^[[COND_OUT:.*]](%{{.*}}, %[[INIT1]], %[[INIT2]]
|
||||||
// CHECK: ^[[COND_OUT]](%{{.*}}: index, %[[ITER_ARG1_OUT:.*]]: f32, %[[ITER_ARG2_OUT:.*]]: i64
|
// CHECK: ^[[COND_OUT]](%{{.*}}: index, %[[ITER_ARG1_OUT:.*]]: f32, %[[ITER_ARG2_OUT:.*]]: i64
|
||||||
// CHECK: cond_br %{{.*}}, ^[[BODY_OUT:.*]], ^[[CONT_OUT:.*]]
|
// CHECK: cf.cond_br %{{.*}}, ^[[BODY_OUT:.*]], ^[[CONT_OUT:.*]]
|
||||||
// CHECK: ^[[BODY_OUT]]:
|
// CHECK: ^[[BODY_OUT]]:
|
||||||
// CHECK: br ^[[COND_IN:.*]](%{{.*}}, %[[ITER_ARG1_OUT]], %[[ITER_ARG2_OUT]]
|
// CHECK: cf.br ^[[COND_IN:.*]](%{{.*}}, %[[ITER_ARG1_OUT]], %[[ITER_ARG2_OUT]]
|
||||||
// CHECK: ^[[COND_IN]](%{{.*}}: index, %[[ITER_ARG1_IN:.*]]: f32, %[[ITER_ARG2_IN:.*]]: i64
|
// CHECK: ^[[COND_IN]](%{{.*}}: index, %[[ITER_ARG1_IN:.*]]: f32, %[[ITER_ARG2_IN:.*]]: i64
|
||||||
// CHECK: cond_br %{{.*}}, ^[[BODY_IN:.*]], ^[[CONT_IN:.*]]
|
// CHECK: cf.cond_br %{{.*}}, ^[[BODY_IN:.*]], ^[[CONT_IN:.*]]
|
||||||
// CHECK: ^[[BODY_IN]]:
|
// CHECK: ^[[BODY_IN]]:
|
||||||
// CHECK: %[[REDUCE1:.*]] = arith.addf %[[ITER_ARG1_IN]], %{{.*}}
|
// CHECK: %[[REDUCE1:.*]] = arith.addf %[[ITER_ARG1_IN]], %{{.*}}
|
||||||
// CHECK: %[[REDUCE2:.*]] = arith.ori %[[ITER_ARG2_IN]], %{{.*}}
|
// CHECK: %[[REDUCE2:.*]] = arith.ori %[[ITER_ARG2_IN]], %{{.*}}
|
||||||
// CHECK: br ^[[COND_IN]](%{{.*}}, %[[REDUCE1]], %[[REDUCE2]]
|
// CHECK: cf.br ^[[COND_IN]](%{{.*}}, %[[REDUCE1]], %[[REDUCE2]]
|
||||||
// CHECK: ^[[CONT_IN]]:
|
// CHECK: ^[[CONT_IN]]:
|
||||||
// CHECK: br ^[[COND_OUT]](%{{.*}}, %[[ITER_ARG1_IN]], %[[ITER_ARG2_IN]]
|
// CHECK: cf.br ^[[COND_OUT]](%{{.*}}, %[[ITER_ARG1_IN]], %[[ITER_ARG2_IN]]
|
||||||
// CHECK: ^[[CONT_OUT]]:
|
// CHECK: ^[[CONT_OUT]]:
|
||||||
// CHECK: return %[[ITER_ARG1_OUT]], %[[ITER_ARG2_OUT]]
|
// CHECK: return %[[ITER_ARG1_OUT]], %[[ITER_ARG2_OUT]]
|
||||||
%step = arith.constant 1 : index
|
%step = arith.constant 1 : index
|
||||||
|
@ -416,17 +416,17 @@ func @unknown_op_inside_loop(%arg0: index, %arg1: index, %arg2: index) {
|
||||||
// CHECK-LABEL: @minimal_while
|
// CHECK-LABEL: @minimal_while
|
||||||
func @minimal_while() {
|
func @minimal_while() {
|
||||||
// CHECK: %[[COND:.*]] = "test.make_condition"() : () -> i1
|
// CHECK: %[[COND:.*]] = "test.make_condition"() : () -> i1
|
||||||
// CHECK: br ^[[BEFORE:.*]]
|
// CHECK: cf.br ^[[BEFORE:.*]]
|
||||||
%0 = "test.make_condition"() : () -> i1
|
%0 = "test.make_condition"() : () -> i1
|
||||||
scf.while : () -> () {
|
scf.while : () -> () {
|
||||||
// CHECK: ^[[BEFORE]]:
|
// CHECK: ^[[BEFORE]]:
|
||||||
// CHECK: cond_br %[[COND]], ^[[AFTER:.*]], ^[[CONT:.*]]
|
// CHECK: cf.cond_br %[[COND]], ^[[AFTER:.*]], ^[[CONT:.*]]
|
||||||
scf.condition(%0)
|
scf.condition(%0)
|
||||||
} do {
|
} do {
|
||||||
// CHECK: ^[[AFTER]]:
|
// CHECK: ^[[AFTER]]:
|
||||||
// CHECK: "test.some_payload"() : () -> ()
|
// CHECK: "test.some_payload"() : () -> ()
|
||||||
"test.some_payload"() : () -> ()
|
"test.some_payload"() : () -> ()
|
||||||
// CHECK: br ^[[BEFORE]]
|
// CHECK: cf.br ^[[BEFORE]]
|
||||||
scf.yield
|
scf.yield
|
||||||
}
|
}
|
||||||
// CHECK: ^[[CONT]]:
|
// CHECK: ^[[CONT]]:
|
||||||
|
@ -436,16 +436,16 @@ func @minimal_while() {
|
||||||
|
|
||||||
// CHECK-LABEL: @do_while
|
// CHECK-LABEL: @do_while
|
||||||
func @do_while(%arg0: f32) {
|
func @do_while(%arg0: f32) {
|
||||||
// CHECK: br ^[[BEFORE:.*]]({{.*}}: f32)
|
// CHECK: cf.br ^[[BEFORE:.*]]({{.*}}: f32)
|
||||||
scf.while (%arg1 = %arg0) : (f32) -> (f32) {
|
scf.while (%arg1 = %arg0) : (f32) -> (f32) {
|
||||||
// CHECK: ^[[BEFORE]](%[[VAL:.*]]: f32):
|
// CHECK: ^[[BEFORE]](%[[VAL:.*]]: f32):
|
||||||
// CHECK: %[[COND:.*]] = "test.make_condition"() : () -> i1
|
// CHECK: %[[COND:.*]] = "test.make_condition"() : () -> i1
|
||||||
%0 = "test.make_condition"() : () -> i1
|
%0 = "test.make_condition"() : () -> i1
|
||||||
// CHECK: cond_br %[[COND]], ^[[BEFORE]](%[[VAL]] : f32), ^[[CONT:.*]]
|
// CHECK: cf.cond_br %[[COND]], ^[[BEFORE]](%[[VAL]] : f32), ^[[CONT:.*]]
|
||||||
scf.condition(%0) %arg1 : f32
|
scf.condition(%0) %arg1 : f32
|
||||||
} do {
|
} do {
|
||||||
^bb0(%arg2: f32):
|
^bb0(%arg2: f32):
|
||||||
// CHECK-NOT: br ^[[BEFORE]]
|
// CHECK-NOT: cf.br ^[[BEFORE]]
|
||||||
scf.yield %arg2 : f32
|
scf.yield %arg2 : f32
|
||||||
}
|
}
|
||||||
// CHECK: ^[[CONT]]:
|
// CHECK: ^[[CONT]]:
|
||||||
|
@ -460,21 +460,21 @@ func @while_values(%arg0: i32, %arg1: f32) {
|
||||||
%0 = "test.make_condition"() : () -> i1
|
%0 = "test.make_condition"() : () -> i1
|
||||||
%c0_i32 = arith.constant 0 : i32
|
%c0_i32 = arith.constant 0 : i32
|
||||||
%cst = arith.constant 0.000000e+00 : f32
|
%cst = arith.constant 0.000000e+00 : f32
|
||||||
// CHECK: br ^[[BEFORE:.*]](%[[ARG0]], %[[ARG1]] : i32, f32)
|
// CHECK: cf.br ^[[BEFORE:.*]](%[[ARG0]], %[[ARG1]] : i32, f32)
|
||||||
%1:2 = scf.while (%arg2 = %arg0, %arg3 = %arg1) : (i32, f32) -> (i64, f64) {
|
%1:2 = scf.while (%arg2 = %arg0, %arg3 = %arg1) : (i32, f32) -> (i64, f64) {
|
||||||
// CHECK: ^bb1(%[[ARG2:.*]]: i32, %[[ARG3:.*]]: f32):
|
// CHECK: ^bb1(%[[ARG2:.*]]: i32, %[[ARG3:.*]]: f32):
|
||||||
// CHECK: %[[VAL1:.*]] = arith.extui %[[ARG0]] : i32 to i64
|
// CHECK: %[[VAL1:.*]] = arith.extui %[[ARG0]] : i32 to i64
|
||||||
%2 = arith.extui %arg0 : i32 to i64
|
%2 = arith.extui %arg0 : i32 to i64
|
||||||
// CHECK: %[[VAL2:.*]] = arith.extf %[[ARG3]] : f32 to f64
|
// CHECK: %[[VAL2:.*]] = arith.extf %[[ARG3]] : f32 to f64
|
||||||
%3 = arith.extf %arg3 : f32 to f64
|
%3 = arith.extf %arg3 : f32 to f64
|
||||||
// CHECK: cond_br %[[COND]],
|
// CHECK: cf.cond_br %[[COND]],
|
||||||
// CHECK: ^[[AFTER:.*]](%[[VAL1]], %[[VAL2]] : i64, f64),
|
// CHECK: ^[[AFTER:.*]](%[[VAL1]], %[[VAL2]] : i64, f64),
|
||||||
// CHECK: ^[[CONT:.*]]
|
// CHECK: ^[[CONT:.*]]
|
||||||
scf.condition(%0) %2, %3 : i64, f64
|
scf.condition(%0) %2, %3 : i64, f64
|
||||||
} do {
|
} do {
|
||||||
// CHECK: ^[[AFTER]](%[[ARG4:.*]]: i64, %[[ARG5:.*]]: f64):
|
// CHECK: ^[[AFTER]](%[[ARG4:.*]]: i64, %[[ARG5:.*]]: f64):
|
||||||
^bb0(%arg2: i64, %arg3: f64):
|
^bb0(%arg2: i64, %arg3: f64):
|
||||||
// CHECK: br ^[[BEFORE]](%{{.*}}, %{{.*}} : i32, f32)
|
// CHECK: cf.br ^[[BEFORE]](%{{.*}}, %{{.*}} : i32, f32)
|
||||||
scf.yield %c0_i32, %cst : i32, f32
|
scf.yield %c0_i32, %cst : i32, f32
|
||||||
}
|
}
|
||||||
// CHECK: ^bb3:
|
// CHECK: ^bb3:
|
||||||
|
@ -484,17 +484,17 @@ func @while_values(%arg0: i32, %arg1: f32) {
|
||||||
|
|
||||||
// CHECK-LABEL: @nested_while_ops
|
// CHECK-LABEL: @nested_while_ops
|
||||||
func @nested_while_ops(%arg0: f32) -> i64 {
|
func @nested_while_ops(%arg0: f32) -> i64 {
|
||||||
// CHECK: br ^[[OUTER_BEFORE:.*]](%{{.*}} : f32)
|
// CHECK: cf.br ^[[OUTER_BEFORE:.*]](%{{.*}} : f32)
|
||||||
%0 = scf.while(%outer = %arg0) : (f32) -> i64 {
|
%0 = scf.while(%outer = %arg0) : (f32) -> i64 {
|
||||||
// CHECK: ^[[OUTER_BEFORE]](%{{.*}}: f32):
|
// CHECK: ^[[OUTER_BEFORE]](%{{.*}}: f32):
|
||||||
// CHECK: %[[OUTER_COND:.*]] = "test.outer_before_pre"() : () -> i1
|
// CHECK: %[[OUTER_COND:.*]] = "test.outer_before_pre"() : () -> i1
|
||||||
%cond = "test.outer_before_pre"() : () -> i1
|
%cond = "test.outer_before_pre"() : () -> i1
|
||||||
// CHECK: br ^[[INNER_BEFORE_BEFORE:.*]](%{{.*}} : f32)
|
// CHECK: cf.br ^[[INNER_BEFORE_BEFORE:.*]](%{{.*}} : f32)
|
||||||
%1 = scf.while(%inner = %outer) : (f32) -> i64 {
|
%1 = scf.while(%inner = %outer) : (f32) -> i64 {
|
||||||
// CHECK: ^[[INNER_BEFORE_BEFORE]](%{{.*}}: f32):
|
// CHECK: ^[[INNER_BEFORE_BEFORE]](%{{.*}}: f32):
|
||||||
// CHECK: %[[INNER1:.*]]:2 = "test.inner_before"(%{{.*}}) : (f32) -> (i1, i64)
|
// CHECK: %[[INNER1:.*]]:2 = "test.inner_before"(%{{.*}}) : (f32) -> (i1, i64)
|
||||||
%2:2 = "test.inner_before"(%inner) : (f32) -> (i1, i64)
|
%2:2 = "test.inner_before"(%inner) : (f32) -> (i1, i64)
|
||||||
// CHECK: cond_br %[[INNER1]]#0,
|
// CHECK: cf.cond_br %[[INNER1]]#0,
|
||||||
// CHECK: ^[[INNER_BEFORE_AFTER:.*]](%[[INNER1]]#1 : i64),
|
// CHECK: ^[[INNER_BEFORE_AFTER:.*]](%[[INNER1]]#1 : i64),
|
||||||
// CHECK: ^[[OUTER_BEFORE_LAST:.*]]
|
// CHECK: ^[[OUTER_BEFORE_LAST:.*]]
|
||||||
scf.condition(%2#0) %2#1 : i64
|
scf.condition(%2#0) %2#1 : i64
|
||||||
|
@ -503,13 +503,13 @@ func @nested_while_ops(%arg0: f32) -> i64 {
|
||||||
^bb0(%arg1: i64):
|
^bb0(%arg1: i64):
|
||||||
// CHECK: %[[INNER2:.*]] = "test.inner_after"(%{{.*}}) : (i64) -> f32
|
// CHECK: %[[INNER2:.*]] = "test.inner_after"(%{{.*}}) : (i64) -> f32
|
||||||
%3 = "test.inner_after"(%arg1) : (i64) -> f32
|
%3 = "test.inner_after"(%arg1) : (i64) -> f32
|
||||||
// CHECK: br ^[[INNER_BEFORE_BEFORE]](%[[INNER2]] : f32)
|
// CHECK: cf.br ^[[INNER_BEFORE_BEFORE]](%[[INNER2]] : f32)
|
||||||
scf.yield %3 : f32
|
scf.yield %3 : f32
|
||||||
}
|
}
|
||||||
// CHECK: ^[[OUTER_BEFORE_LAST]]:
|
// CHECK: ^[[OUTER_BEFORE_LAST]]:
|
||||||
// CHECK: "test.outer_before_post"() : () -> ()
|
// CHECK: "test.outer_before_post"() : () -> ()
|
||||||
"test.outer_before_post"() : () -> ()
|
"test.outer_before_post"() : () -> ()
|
||||||
// CHECK: cond_br %[[OUTER_COND]],
|
// CHECK: cf.cond_br %[[OUTER_COND]],
|
||||||
// CHECK: ^[[OUTER_AFTER:.*]](%[[INNER1]]#1 : i64),
|
// CHECK: ^[[OUTER_AFTER:.*]](%[[INNER1]]#1 : i64),
|
||||||
// CHECK: ^[[CONTINUATION:.*]]
|
// CHECK: ^[[CONTINUATION:.*]]
|
||||||
scf.condition(%cond) %1 : i64
|
scf.condition(%cond) %1 : i64
|
||||||
|
@ -518,12 +518,12 @@ func @nested_while_ops(%arg0: f32) -> i64 {
|
||||||
^bb2(%arg2: i64):
|
^bb2(%arg2: i64):
|
||||||
// CHECK: "test.outer_after_pre"(%{{.*}}) : (i64) -> ()
|
// CHECK: "test.outer_after_pre"(%{{.*}}) : (i64) -> ()
|
||||||
"test.outer_after_pre"(%arg2) : (i64) -> ()
|
"test.outer_after_pre"(%arg2) : (i64) -> ()
|
||||||
// CHECK: br ^[[INNER_AFTER_BEFORE:.*]](%{{.*}} : i64)
|
// CHECK: cf.br ^[[INNER_AFTER_BEFORE:.*]](%{{.*}} : i64)
|
||||||
%4 = scf.while(%inner = %arg2) : (i64) -> f32 {
|
%4 = scf.while(%inner = %arg2) : (i64) -> f32 {
|
||||||
// CHECK: ^[[INNER_AFTER_BEFORE]](%{{.*}}: i64):
|
// CHECK: ^[[INNER_AFTER_BEFORE]](%{{.*}}: i64):
|
||||||
// CHECK: %[[INNER3:.*]]:2 = "test.inner2_before"(%{{.*}}) : (i64) -> (i1, f32)
|
// CHECK: %[[INNER3:.*]]:2 = "test.inner2_before"(%{{.*}}) : (i64) -> (i1, f32)
|
||||||
%5:2 = "test.inner2_before"(%inner) : (i64) -> (i1, f32)
|
%5:2 = "test.inner2_before"(%inner) : (i64) -> (i1, f32)
|
||||||
// CHECK: cond_br %[[INNER3]]#0,
|
// CHECK: cf.cond_br %[[INNER3]]#0,
|
||||||
// CHECK: ^[[INNER_AFTER_AFTER:.*]](%[[INNER3]]#1 : f32),
|
// CHECK: ^[[INNER_AFTER_AFTER:.*]](%[[INNER3]]#1 : f32),
|
||||||
// CHECK: ^[[OUTER_AFTER_LAST:.*]]
|
// CHECK: ^[[OUTER_AFTER_LAST:.*]]
|
||||||
scf.condition(%5#0) %5#1 : f32
|
scf.condition(%5#0) %5#1 : f32
|
||||||
|
@ -532,13 +532,13 @@ func @nested_while_ops(%arg0: f32) -> i64 {
|
||||||
^bb3(%arg3: f32):
|
^bb3(%arg3: f32):
|
||||||
// CHECK: %{{.*}} = "test.inner2_after"(%{{.*}}) : (f32) -> i64
|
// CHECK: %{{.*}} = "test.inner2_after"(%{{.*}}) : (f32) -> i64
|
||||||
%6 = "test.inner2_after"(%arg3) : (f32) -> i64
|
%6 = "test.inner2_after"(%arg3) : (f32) -> i64
|
||||||
// CHECK: br ^[[INNER_AFTER_BEFORE]](%{{.*}} : i64)
|
// CHECK: cf.br ^[[INNER_AFTER_BEFORE]](%{{.*}} : i64)
|
||||||
scf.yield %6 : i64
|
scf.yield %6 : i64
|
||||||
}
|
}
|
||||||
// CHECK: ^[[OUTER_AFTER_LAST]]:
|
// CHECK: ^[[OUTER_AFTER_LAST]]:
|
||||||
// CHECK: "test.outer_after_post"() : () -> ()
|
// CHECK: "test.outer_after_post"() : () -> ()
|
||||||
"test.outer_after_post"() : () -> ()
|
"test.outer_after_post"() : () -> ()
|
||||||
// CHECK: br ^[[OUTER_BEFORE]](%[[INNER3]]#1 : f32)
|
// CHECK: cf.br ^[[OUTER_BEFORE]](%[[INNER3]]#1 : f32)
|
||||||
scf.yield %4 : f32
|
scf.yield %4 : f32
|
||||||
}
|
}
|
||||||
// CHECK: ^[[CONTINUATION]]:
|
// CHECK: ^[[CONTINUATION]]:
|
||||||
|
@ -549,27 +549,27 @@ func @nested_while_ops(%arg0: f32) -> i64 {
|
||||||
// CHECK-LABEL: @ifs_in_parallel
|
// CHECK-LABEL: @ifs_in_parallel
|
||||||
// CHECK: (%[[ARG0:.*]]: index, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index, %[[ARG3:.*]]: i1, %[[ARG4:.*]]: i1)
|
// CHECK: (%[[ARG0:.*]]: index, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index, %[[ARG3:.*]]: i1, %[[ARG4:.*]]: i1)
|
||||||
func @ifs_in_parallel(%arg1: index, %arg2: index, %arg3: index, %arg4: i1, %arg5: i1) {
|
func @ifs_in_parallel(%arg1: index, %arg2: index, %arg3: index, %arg4: i1, %arg5: i1) {
|
||||||
// CHECK: br ^[[LOOP_LATCH:.*]](%[[ARG0]] : index)
|
// CHECK: cf.br ^[[LOOP_LATCH:.*]](%[[ARG0]] : index)
|
||||||
// CHECK: ^[[LOOP_LATCH]](%[[LOOP_IV:.*]]: index):
|
// CHECK: ^[[LOOP_LATCH]](%[[LOOP_IV:.*]]: index):
|
||||||
// CHECK: %[[LOOP_COND:.*]] = arith.cmpi slt, %[[LOOP_IV]], %[[ARG1]] : index
|
// CHECK: %[[LOOP_COND:.*]] = arith.cmpi slt, %[[LOOP_IV]], %[[ARG1]] : index
|
||||||
// CHECK: cond_br %[[LOOP_COND]], ^[[LOOP_BODY:.*]], ^[[LOOP_CONT:.*]]
|
// CHECK: cf.cond_br %[[LOOP_COND]], ^[[LOOP_BODY:.*]], ^[[LOOP_CONT:.*]]
|
||||||
// CHECK: ^[[LOOP_BODY]]:
|
// CHECK: ^[[LOOP_BODY]]:
|
||||||
// CHECK: cond_br %[[ARG3]], ^[[IF1_THEN:.*]], ^[[IF1_CONT:.*]]
|
// CHECK: cf.cond_br %[[ARG3]], ^[[IF1_THEN:.*]], ^[[IF1_CONT:.*]]
|
||||||
// CHECK: ^[[IF1_THEN]]:
|
// CHECK: ^[[IF1_THEN]]:
|
||||||
// CHECK: cond_br %[[ARG4]], ^[[IF2_THEN:.*]], ^[[IF2_ELSE:.*]]
|
// CHECK: cf.cond_br %[[ARG4]], ^[[IF2_THEN:.*]], ^[[IF2_ELSE:.*]]
|
||||||
// CHECK: ^[[IF2_THEN]]:
|
// CHECK: ^[[IF2_THEN]]:
|
||||||
// CHECK: %{{.*}} = "test.if2"() : () -> index
|
// CHECK: %{{.*}} = "test.if2"() : () -> index
|
||||||
// CHECK: br ^[[IF2_MERGE:.*]](%{{.*}} : index)
|
// CHECK: cf.br ^[[IF2_MERGE:.*]](%{{.*}} : index)
|
||||||
// CHECK: ^[[IF2_ELSE]]:
|
// CHECK: ^[[IF2_ELSE]]:
|
||||||
// CHECK: %{{.*}} = "test.else2"() : () -> index
|
// CHECK: %{{.*}} = "test.else2"() : () -> index
|
||||||
// CHECK: br ^[[IF2_MERGE]](%{{.*}} : index)
|
// CHECK: cf.br ^[[IF2_MERGE]](%{{.*}} : index)
|
||||||
// CHECK: ^[[IF2_MERGE]](%{{.*}}: index):
|
// CHECK: ^[[IF2_MERGE]](%{{.*}}: index):
|
||||||
// CHECK: br ^[[IF2_CONT:.*]]
|
// CHECK: cf.br ^[[IF2_CONT:.*]]
|
||||||
// CHECK: ^[[IF2_CONT]]:
|
// CHECK: ^[[IF2_CONT]]:
|
||||||
// CHECK: br ^[[IF1_CONT]]
|
// CHECK: cf.br ^[[IF1_CONT]]
|
||||||
// CHECK: ^[[IF1_CONT]]:
|
// CHECK: ^[[IF1_CONT]]:
|
||||||
// CHECK: %{{.*}} = arith.addi %[[LOOP_IV]], %[[ARG2]] : index
|
// CHECK: %{{.*}} = arith.addi %[[LOOP_IV]], %[[ARG2]] : index
|
||||||
// CHECK: br ^[[LOOP_LATCH]](%{{.*}} : index)
|
// CHECK: cf.br ^[[LOOP_LATCH]](%{{.*}} : index)
|
||||||
scf.parallel (%i) = (%arg1) to (%arg2) step (%arg3) {
|
scf.parallel (%i) = (%arg1) to (%arg2) step (%arg3) {
|
||||||
scf.if %arg4 {
|
scf.if %arg4 {
|
||||||
%0 = scf.if %arg5 -> (index) {
|
%0 = scf.if %arg5 -> (index) {
|
||||||
|
@ -593,7 +593,7 @@ func @func_execute_region_elim_multi_yield() {
|
||||||
"test.foo"() : () -> ()
|
"test.foo"() : () -> ()
|
||||||
%v = scf.execute_region -> i64 {
|
%v = scf.execute_region -> i64 {
|
||||||
%c = "test.cmp"() : () -> i1
|
%c = "test.cmp"() : () -> i1
|
||||||
cond_br %c, ^bb2, ^bb3
|
cf.cond_br %c, ^bb2, ^bb3
|
||||||
^bb2:
|
^bb2:
|
||||||
%x = "test.val1"() : () -> i64
|
%x = "test.val1"() : () -> i64
|
||||||
scf.yield %x : i64
|
scf.yield %x : i64
|
||||||
|
@ -607,16 +607,16 @@ func @func_execute_region_elim_multi_yield() {
|
||||||
|
|
||||||
// CHECK-NOT: execute_region
|
// CHECK-NOT: execute_region
|
||||||
// CHECK: "test.foo"
|
// CHECK: "test.foo"
|
||||||
// CHECK: br ^[[rentry:.+]]
|
// CHECK: cf.br ^[[rentry:.+]]
|
||||||
// CHECK: ^[[rentry]]
|
// CHECK: ^[[rentry]]
|
||||||
// CHECK: %[[cmp:.+]] = "test.cmp"
|
// CHECK: %[[cmp:.+]] = "test.cmp"
|
||||||
// CHECK: cond_br %[[cmp]], ^[[bb1:.+]], ^[[bb2:.+]]
|
// CHECK: cf.cond_br %[[cmp]], ^[[bb1:.+]], ^[[bb2:.+]]
|
||||||
// CHECK: ^[[bb1]]:
|
// CHECK: ^[[bb1]]:
|
||||||
// CHECK: %[[x:.+]] = "test.val1"
|
// CHECK: %[[x:.+]] = "test.val1"
|
||||||
// CHECK: br ^[[bb3:.+]](%[[x]] : i64)
|
// CHECK: cf.br ^[[bb3:.+]](%[[x]] : i64)
|
||||||
// CHECK: ^[[bb2]]:
|
// CHECK: ^[[bb2]]:
|
||||||
// CHECK: %[[y:.+]] = "test.val2"
|
// CHECK: %[[y:.+]] = "test.val2"
|
||||||
// CHECK: br ^[[bb3]](%[[y:.+]] : i64)
|
// CHECK: cf.br ^[[bb3]](%[[y:.+]] : i64)
|
||||||
// CHECK: ^[[bb3]](%[[z:.+]]: i64):
|
// CHECK: ^[[bb3]](%[[z:.+]]: i64):
|
||||||
// CHECK: "test.bar"(%[[z]])
|
// CHECK: "test.bar"(%[[z]])
|
||||||
// CHECK: return
|
// CHECK: return
|
|
@ -6,7 +6,7 @@
|
||||||
// CHECK-SAME: %[[RHS:.*]]: tensor<?xindex>) -> !shape.witness {
|
// CHECK-SAME: %[[RHS:.*]]: tensor<?xindex>) -> !shape.witness {
|
||||||
// CHECK: %[[RET:.*]] = shape.const_witness true
|
// CHECK: %[[RET:.*]] = shape.const_witness true
|
||||||
// CHECK: %[[BROADCAST_IS_VALID:.*]] = shape.is_broadcastable %[[LHS]], %[[RHS]]
|
// CHECK: %[[BROADCAST_IS_VALID:.*]] = shape.is_broadcastable %[[LHS]], %[[RHS]]
|
||||||
// CHECK: assert %[[BROADCAST_IS_VALID]], "required broadcastable shapes"
|
// CHECK: cf.assert %[[BROADCAST_IS_VALID]], "required broadcastable shapes"
|
||||||
// CHECK: return %[[RET]] : !shape.witness
|
// CHECK: return %[[RET]] : !shape.witness
|
||||||
// CHECK: }
|
// CHECK: }
|
||||||
func @cstr_broadcastable(%arg0: tensor<?xindex>, %arg1: tensor<?xindex>) -> !shape.witness {
|
func @cstr_broadcastable(%arg0: tensor<?xindex>, %arg1: tensor<?xindex>) -> !shape.witness {
|
||||||
|
@ -19,7 +19,7 @@ func @cstr_broadcastable(%arg0: tensor<?xindex>, %arg1: tensor<?xindex>) -> !sha
|
||||||
// CHECK-SAME: %[[RHS:.*]]: tensor<?xindex>) -> !shape.witness {
|
// CHECK-SAME: %[[RHS:.*]]: tensor<?xindex>) -> !shape.witness {
|
||||||
// CHECK: %[[RET:.*]] = shape.const_witness true
|
// CHECK: %[[RET:.*]] = shape.const_witness true
|
||||||
// CHECK: %[[EQUAL_IS_VALID:.*]] = shape.shape_eq %[[LHS]], %[[RHS]]
|
// CHECK: %[[EQUAL_IS_VALID:.*]] = shape.shape_eq %[[LHS]], %[[RHS]]
|
||||||
// CHECK: assert %[[EQUAL_IS_VALID]], "required equal shapes"
|
// CHECK: cf.assert %[[EQUAL_IS_VALID]], "required equal shapes"
|
||||||
// CHECK: return %[[RET]] : !shape.witness
|
// CHECK: return %[[RET]] : !shape.witness
|
||||||
// CHECK: }
|
// CHECK: }
|
||||||
func @cstr_eq(%arg0: tensor<?xindex>, %arg1: tensor<?xindex>) -> !shape.witness {
|
func @cstr_eq(%arg0: tensor<?xindex>, %arg1: tensor<?xindex>) -> !shape.witness {
|
||||||
|
@ -30,7 +30,7 @@ func @cstr_eq(%arg0: tensor<?xindex>, %arg1: tensor<?xindex>) -> !shape.witness
|
||||||
// CHECK-LABEL: func @cstr_require
|
// CHECK-LABEL: func @cstr_require
|
||||||
func @cstr_require(%arg0: i1) -> !shape.witness {
|
func @cstr_require(%arg0: i1) -> !shape.witness {
|
||||||
// CHECK: %[[RET:.*]] = shape.const_witness true
|
// CHECK: %[[RET:.*]] = shape.const_witness true
|
||||||
// CHECK: assert %arg0, "msg"
|
// CHECK: cf.assert %arg0, "msg"
|
||||||
// CHECK: return %[[RET]]
|
// CHECK: return %[[RET]]
|
||||||
%witness = shape.cstr_require %arg0, "msg"
|
%witness = shape.cstr_require %arg0, "msg"
|
||||||
return %witness : !shape.witness
|
return %witness : !shape.witness
|
||||||
|
|
|
@ -29,7 +29,7 @@ func private @memref_call_conv_nested(%arg0: (memref<?xf32>) -> ())
|
||||||
//CHECK-LABEL: llvm.func @pass_through(%arg0: !llvm.ptr<func<void ()>>) -> !llvm.ptr<func<void ()>> {
|
//CHECK-LABEL: llvm.func @pass_through(%arg0: !llvm.ptr<func<void ()>>) -> !llvm.ptr<func<void ()>> {
|
||||||
func @pass_through(%arg0: () -> ()) -> (() -> ()) {
|
func @pass_through(%arg0: () -> ()) -> (() -> ()) {
|
||||||
// CHECK-NEXT: llvm.br ^bb1(%arg0 : !llvm.ptr<func<void ()>>)
|
// CHECK-NEXT: llvm.br ^bb1(%arg0 : !llvm.ptr<func<void ()>>)
|
||||||
br ^bb1(%arg0 : () -> ())
|
cf.br ^bb1(%arg0 : () -> ())
|
||||||
|
|
||||||
//CHECK-NEXT: ^bb1(%0: !llvm.ptr<func<void ()>>):
|
//CHECK-NEXT: ^bb1(%0: !llvm.ptr<func<void ()>>):
|
||||||
^bb1(%bbarg: () -> ()):
|
^bb1(%bbarg: () -> ()):
|
||||||
|
|
|
@ -109,17 +109,17 @@ func @loop_carried(%arg0 : index, %arg1 : index, %arg2 : index, %base0 : !base_t
|
||||||
// This test checks that in the BAREPTR case, the branch arguments only forward the descriptor.
|
// This test checks that in the BAREPTR case, the branch arguments only forward the descriptor.
|
||||||
// This test was lowered from a simple scf.for that swaps 2 memref iter_args.
|
// This test was lowered from a simple scf.for that swaps 2 memref iter_args.
|
||||||
// BAREPTR: llvm.br ^bb1(%{{.*}}, %{{.*}}, %{{.*}} : i64, !llvm.struct<(ptr<i32, 201>, ptr<i32, 201>, i64, array<1 x i64>, array<1 x i64>)>, !llvm.struct<(ptr<i32, 201>, ptr<i32, 201>, i64, array<1 x i64>, array<1 x i64>)>)
|
// BAREPTR: llvm.br ^bb1(%{{.*}}, %{{.*}}, %{{.*}} : i64, !llvm.struct<(ptr<i32, 201>, ptr<i32, 201>, i64, array<1 x i64>, array<1 x i64>)>, !llvm.struct<(ptr<i32, 201>, ptr<i32, 201>, i64, array<1 x i64>, array<1 x i64>)>)
|
||||||
br ^bb1(%arg0, %base0, %base1 : index, memref<64xi32, 201>, memref<64xi32, 201>)
|
cf.br ^bb1(%arg0, %base0, %base1 : index, memref<64xi32, 201>, memref<64xi32, 201>)
|
||||||
|
|
||||||
// BAREPTR-NEXT: ^bb1
|
// BAREPTR-NEXT: ^bb1
|
||||||
// BAREPTR-NEXT: llvm.icmp
|
// BAREPTR-NEXT: llvm.icmp
|
||||||
// BAREPTR-NEXT: llvm.cond_br %{{.*}}, ^bb2, ^bb3
|
// BAREPTR-NEXT: llvm.cond_br %{{.*}}, ^bb2, ^bb3
|
||||||
^bb1(%0: index, %1: memref<64xi32, 201>, %2: memref<64xi32, 201>): // 2 preds: ^bb0, ^bb2
|
^bb1(%0: index, %1: memref<64xi32, 201>, %2: memref<64xi32, 201>): // 2 preds: ^bb0, ^bb2
|
||||||
%3 = arith.cmpi slt, %0, %arg1 : index
|
%3 = arith.cmpi slt, %0, %arg1 : index
|
||||||
cond_br %3, ^bb2, ^bb3
|
cf.cond_br %3, ^bb2, ^bb3
|
||||||
^bb2: // pred: ^bb1
|
^bb2: // pred: ^bb1
|
||||||
%4 = arith.addi %0, %arg2 : index
|
%4 = arith.addi %0, %arg2 : index
|
||||||
br ^bb1(%4, %2, %1 : index, memref<64xi32, 201>, memref<64xi32, 201>)
|
cf.br ^bb1(%4, %2, %1 : index, memref<64xi32, 201>, memref<64xi32, 201>)
|
||||||
^bb3: // pred: ^bb1
|
^bb3: // pred: ^bb1
|
||||||
return %1, %2 : memref<64xi32, 201>, memref<64xi32, 201>
|
return %1, %2 : memref<64xi32, 201>, memref<64xi32, 201>
|
||||||
}
|
}
|
||||||
|
|
|
@ -18,7 +18,7 @@ func @simple_loop() {
|
||||||
^bb0:
|
^bb0:
|
||||||
// CHECK-NEXT: llvm.br ^bb1
|
// CHECK-NEXT: llvm.br ^bb1
|
||||||
// CHECK32-NEXT: llvm.br ^bb1
|
// CHECK32-NEXT: llvm.br ^bb1
|
||||||
br ^bb1
|
cf.br ^bb1
|
||||||
|
|
||||||
// CHECK-NEXT: ^bb1: // pred: ^bb0
|
// CHECK-NEXT: ^bb1: // pred: ^bb0
|
||||||
// CHECK-NEXT: {{.*}} = llvm.mlir.constant(1 : index) : i64
|
// CHECK-NEXT: {{.*}} = llvm.mlir.constant(1 : index) : i64
|
||||||
|
@ -31,7 +31,7 @@ func @simple_loop() {
|
||||||
^bb1: // pred: ^bb0
|
^bb1: // pred: ^bb0
|
||||||
%c1 = arith.constant 1 : index
|
%c1 = arith.constant 1 : index
|
||||||
%c42 = arith.constant 42 : index
|
%c42 = arith.constant 42 : index
|
||||||
br ^bb2(%c1 : index)
|
cf.br ^bb2(%c1 : index)
|
||||||
|
|
||||||
// CHECK: ^bb2({{.*}}: i64): // 2 preds: ^bb1, ^bb3
|
// CHECK: ^bb2({{.*}}: i64): // 2 preds: ^bb1, ^bb3
|
||||||
// CHECK-NEXT: {{.*}} = llvm.icmp "slt" {{.*}}, {{.*}} : i64
|
// CHECK-NEXT: {{.*}} = llvm.icmp "slt" {{.*}}, {{.*}} : i64
|
||||||
|
@ -41,7 +41,7 @@ func @simple_loop() {
|
||||||
// CHECK32-NEXT: llvm.cond_br {{.*}}, ^bb3, ^bb4
|
// CHECK32-NEXT: llvm.cond_br {{.*}}, ^bb3, ^bb4
|
||||||
^bb2(%0: index): // 2 preds: ^bb1, ^bb3
|
^bb2(%0: index): // 2 preds: ^bb1, ^bb3
|
||||||
%1 = arith.cmpi slt, %0, %c42 : index
|
%1 = arith.cmpi slt, %0, %c42 : index
|
||||||
cond_br %1, ^bb3, ^bb4
|
cf.cond_br %1, ^bb3, ^bb4
|
||||||
|
|
||||||
// CHECK: ^bb3: // pred: ^bb2
|
// CHECK: ^bb3: // pred: ^bb2
|
||||||
// CHECK-NEXT: llvm.call @body({{.*}}) : (i64) -> ()
|
// CHECK-NEXT: llvm.call @body({{.*}}) : (i64) -> ()
|
||||||
|
@ -57,7 +57,7 @@ func @simple_loop() {
|
||||||
call @body(%0) : (index) -> ()
|
call @body(%0) : (index) -> ()
|
||||||
%c1_0 = arith.constant 1 : index
|
%c1_0 = arith.constant 1 : index
|
||||||
%2 = arith.addi %0, %c1_0 : index
|
%2 = arith.addi %0, %c1_0 : index
|
||||||
br ^bb2(%2 : index)
|
cf.br ^bb2(%2 : index)
|
||||||
|
|
||||||
// CHECK: ^bb4: // pred: ^bb2
|
// CHECK: ^bb4: // pred: ^bb2
|
||||||
// CHECK-NEXT: llvm.return
|
// CHECK-NEXT: llvm.return
|
||||||
|
@ -111,7 +111,7 @@ func private @other(index, i32) -> i32
|
||||||
func @func_args(i32, i32) -> i32 {
|
func @func_args(i32, i32) -> i32 {
|
||||||
^bb0(%arg0: i32, %arg1: i32):
|
^bb0(%arg0: i32, %arg1: i32):
|
||||||
%c0_i32 = arith.constant 0 : i32
|
%c0_i32 = arith.constant 0 : i32
|
||||||
br ^bb1
|
cf.br ^bb1
|
||||||
|
|
||||||
// CHECK-NEXT: ^bb1: // pred: ^bb0
|
// CHECK-NEXT: ^bb1: // pred: ^bb0
|
||||||
// CHECK-NEXT: {{.*}} = llvm.mlir.constant(0 : index) : i64
|
// CHECK-NEXT: {{.*}} = llvm.mlir.constant(0 : index) : i64
|
||||||
|
@ -124,7 +124,7 @@ func @func_args(i32, i32) -> i32 {
|
||||||
^bb1: // pred: ^bb0
|
^bb1: // pred: ^bb0
|
||||||
%c0 = arith.constant 0 : index
|
%c0 = arith.constant 0 : index
|
||||||
%c42 = arith.constant 42 : index
|
%c42 = arith.constant 42 : index
|
||||||
br ^bb2(%c0 : index)
|
cf.br ^bb2(%c0 : index)
|
||||||
|
|
||||||
// CHECK-NEXT: ^bb2({{.*}}: i64): // 2 preds: ^bb1, ^bb3
|
// CHECK-NEXT: ^bb2({{.*}}: i64): // 2 preds: ^bb1, ^bb3
|
||||||
// CHECK-NEXT: {{.*}} = llvm.icmp "slt" {{.*}}, {{.*}} : i64
|
// CHECK-NEXT: {{.*}} = llvm.icmp "slt" {{.*}}, {{.*}} : i64
|
||||||
|
@ -134,7 +134,7 @@ func @func_args(i32, i32) -> i32 {
|
||||||
// CHECK32-NEXT: llvm.cond_br {{.*}}, ^bb3, ^bb4
|
// CHECK32-NEXT: llvm.cond_br {{.*}}, ^bb3, ^bb4
|
||||||
^bb2(%0: index): // 2 preds: ^bb1, ^bb3
|
^bb2(%0: index): // 2 preds: ^bb1, ^bb3
|
||||||
%1 = arith.cmpi slt, %0, %c42 : index
|
%1 = arith.cmpi slt, %0, %c42 : index
|
||||||
cond_br %1, ^bb3, ^bb4
|
cf.cond_br %1, ^bb3, ^bb4
|
||||||
|
|
||||||
// CHECK-NEXT: ^bb3: // pred: ^bb2
|
// CHECK-NEXT: ^bb3: // pred: ^bb2
|
||||||
// CHECK-NEXT: {{.*}} = llvm.call @body_args({{.*}}) : (i64) -> i64
|
// CHECK-NEXT: {{.*}} = llvm.call @body_args({{.*}}) : (i64) -> i64
|
||||||
|
@ -159,7 +159,7 @@ func @func_args(i32, i32) -> i32 {
|
||||||
%5 = call @other(%2, %arg1) : (index, i32) -> i32
|
%5 = call @other(%2, %arg1) : (index, i32) -> i32
|
||||||
%c1 = arith.constant 1 : index
|
%c1 = arith.constant 1 : index
|
||||||
%6 = arith.addi %0, %c1 : index
|
%6 = arith.addi %0, %c1 : index
|
||||||
br ^bb2(%6 : index)
|
cf.br ^bb2(%6 : index)
|
||||||
|
|
||||||
// CHECK-NEXT: ^bb4: // pred: ^bb2
|
// CHECK-NEXT: ^bb4: // pred: ^bb2
|
||||||
// CHECK-NEXT: {{.*}} = llvm.mlir.constant(0 : index) : i64
|
// CHECK-NEXT: {{.*}} = llvm.mlir.constant(0 : index) : i64
|
||||||
|
@ -191,7 +191,7 @@ func private @post(index)
|
||||||
// CHECK-NEXT: llvm.br ^bb1
|
// CHECK-NEXT: llvm.br ^bb1
|
||||||
func @imperfectly_nested_loops() {
|
func @imperfectly_nested_loops() {
|
||||||
^bb0:
|
^bb0:
|
||||||
br ^bb1
|
cf.br ^bb1
|
||||||
|
|
||||||
// CHECK-NEXT: ^bb1: // pred: ^bb0
|
// CHECK-NEXT: ^bb1: // pred: ^bb0
|
||||||
// CHECK-NEXT: {{.*}} = llvm.mlir.constant(0 : index) : i64
|
// CHECK-NEXT: {{.*}} = llvm.mlir.constant(0 : index) : i64
|
||||||
|
@ -200,21 +200,21 @@ func @imperfectly_nested_loops() {
|
||||||
^bb1: // pred: ^bb0
|
^bb1: // pred: ^bb0
|
||||||
%c0 = arith.constant 0 : index
|
%c0 = arith.constant 0 : index
|
||||||
%c42 = arith.constant 42 : index
|
%c42 = arith.constant 42 : index
|
||||||
br ^bb2(%c0 : index)
|
cf.br ^bb2(%c0 : index)
|
||||||
|
|
||||||
// CHECK-NEXT: ^bb2({{.*}}: i64): // 2 preds: ^bb1, ^bb7
|
// CHECK-NEXT: ^bb2({{.*}}: i64): // 2 preds: ^bb1, ^bb7
|
||||||
// CHECK-NEXT: {{.*}} = llvm.icmp "slt" {{.*}}, {{.*}} : i64
|
// CHECK-NEXT: {{.*}} = llvm.icmp "slt" {{.*}}, {{.*}} : i64
|
||||||
// CHECK-NEXT: llvm.cond_br {{.*}}, ^bb3, ^bb8
|
// CHECK-NEXT: llvm.cond_br {{.*}}, ^bb3, ^bb8
|
||||||
^bb2(%0: index): // 2 preds: ^bb1, ^bb7
|
^bb2(%0: index): // 2 preds: ^bb1, ^bb7
|
||||||
%1 = arith.cmpi slt, %0, %c42 : index
|
%1 = arith.cmpi slt, %0, %c42 : index
|
||||||
cond_br %1, ^bb3, ^bb8
|
cf.cond_br %1, ^bb3, ^bb8
|
||||||
|
|
||||||
// CHECK-NEXT: ^bb3:
|
// CHECK-NEXT: ^bb3:
|
||||||
// CHECK-NEXT: llvm.call @pre({{.*}}) : (i64) -> ()
|
// CHECK-NEXT: llvm.call @pre({{.*}}) : (i64) -> ()
|
||||||
// CHECK-NEXT: llvm.br ^bb4
|
// CHECK-NEXT: llvm.br ^bb4
|
||||||
^bb3: // pred: ^bb2
|
^bb3: // pred: ^bb2
|
||||||
call @pre(%0) : (index) -> ()
|
call @pre(%0) : (index) -> ()
|
||||||
br ^bb4
|
cf.br ^bb4
|
||||||
|
|
||||||
// CHECK-NEXT: ^bb4: // pred: ^bb3
|
// CHECK-NEXT: ^bb4: // pred: ^bb3
|
||||||
// CHECK-NEXT: {{.*}} = llvm.mlir.constant(7 : index) : i64
|
// CHECK-NEXT: {{.*}} = llvm.mlir.constant(7 : index) : i64
|
||||||
|
@ -223,14 +223,14 @@ func @imperfectly_nested_loops() {
|
||||||
^bb4: // pred: ^bb3
|
^bb4: // pred: ^bb3
|
||||||
%c7 = arith.constant 7 : index
|
%c7 = arith.constant 7 : index
|
||||||
%c56 = arith.constant 56 : index
|
%c56 = arith.constant 56 : index
|
||||||
br ^bb5(%c7 : index)
|
cf.br ^bb5(%c7 : index)
|
||||||
|
|
||||||
// CHECK-NEXT: ^bb5({{.*}}: i64): // 2 preds: ^bb4, ^bb6
|
// CHECK-NEXT: ^bb5({{.*}}: i64): // 2 preds: ^bb4, ^bb6
|
||||||
// CHECK-NEXT: {{.*}} = llvm.icmp "slt" {{.*}}, {{.*}} : i64
|
// CHECK-NEXT: {{.*}} = llvm.icmp "slt" {{.*}}, {{.*}} : i64
|
||||||
// CHECK-NEXT: llvm.cond_br {{.*}}, ^bb6, ^bb7
|
// CHECK-NEXT: llvm.cond_br {{.*}}, ^bb6, ^bb7
|
||||||
^bb5(%2: index): // 2 preds: ^bb4, ^bb6
|
^bb5(%2: index): // 2 preds: ^bb4, ^bb6
|
||||||
%3 = arith.cmpi slt, %2, %c56 : index
|
%3 = arith.cmpi slt, %2, %c56 : index
|
||||||
cond_br %3, ^bb6, ^bb7
|
cf.cond_br %3, ^bb6, ^bb7
|
||||||
|
|
||||||
// CHECK-NEXT: ^bb6: // pred: ^bb5
|
// CHECK-NEXT: ^bb6: // pred: ^bb5
|
||||||
// CHECK-NEXT: llvm.call @body2({{.*}}, {{.*}}) : (i64, i64) -> ()
|
// CHECK-NEXT: llvm.call @body2({{.*}}, {{.*}}) : (i64, i64) -> ()
|
||||||
|
@ -241,7 +241,7 @@ func @imperfectly_nested_loops() {
|
||||||
call @body2(%0, %2) : (index, index) -> ()
|
call @body2(%0, %2) : (index, index) -> ()
|
||||||
%c2 = arith.constant 2 : index
|
%c2 = arith.constant 2 : index
|
||||||
%4 = arith.addi %2, %c2 : index
|
%4 = arith.addi %2, %c2 : index
|
||||||
br ^bb5(%4 : index)
|
cf.br ^bb5(%4 : index)
|
||||||
|
|
||||||
// CHECK-NEXT: ^bb7: // pred: ^bb5
|
// CHECK-NEXT: ^bb7: // pred: ^bb5
|
||||||
// CHECK-NEXT: llvm.call @post({{.*}}) : (i64) -> ()
|
// CHECK-NEXT: llvm.call @post({{.*}}) : (i64) -> ()
|
||||||
|
@ -252,7 +252,7 @@ func @imperfectly_nested_loops() {
|
||||||
call @post(%0) : (index) -> ()
|
call @post(%0) : (index) -> ()
|
||||||
%c1 = arith.constant 1 : index
|
%c1 = arith.constant 1 : index
|
||||||
%5 = arith.addi %0, %c1 : index
|
%5 = arith.addi %0, %c1 : index
|
||||||
br ^bb2(%5 : index)
|
cf.br ^bb2(%5 : index)
|
||||||
|
|
||||||
// CHECK-NEXT: ^bb8: // pred: ^bb2
|
// CHECK-NEXT: ^bb8: // pred: ^bb2
|
||||||
// CHECK-NEXT: llvm.return
|
// CHECK-NEXT: llvm.return
|
||||||
|
@ -316,49 +316,49 @@ func private @body3(index, index)
|
||||||
// CHECK-NEXT: }
|
// CHECK-NEXT: }
|
||||||
func @more_imperfectly_nested_loops() {
|
func @more_imperfectly_nested_loops() {
|
||||||
^bb0:
|
^bb0:
|
||||||
br ^bb1
|
cf.br ^bb1
|
||||||
^bb1: // pred: ^bb0
|
^bb1: // pred: ^bb0
|
||||||
%c0 = arith.constant 0 : index
|
%c0 = arith.constant 0 : index
|
||||||
%c42 = arith.constant 42 : index
|
%c42 = arith.constant 42 : index
|
||||||
br ^bb2(%c0 : index)
|
cf.br ^bb2(%c0 : index)
|
||||||
^bb2(%0: index): // 2 preds: ^bb1, ^bb11
|
^bb2(%0: index): // 2 preds: ^bb1, ^bb11
|
||||||
%1 = arith.cmpi slt, %0, %c42 : index
|
%1 = arith.cmpi slt, %0, %c42 : index
|
||||||
cond_br %1, ^bb3, ^bb12
|
cf.cond_br %1, ^bb3, ^bb12
|
||||||
^bb3: // pred: ^bb2
|
^bb3: // pred: ^bb2
|
||||||
call @pre(%0) : (index) -> ()
|
call @pre(%0) : (index) -> ()
|
||||||
br ^bb4
|
cf.br ^bb4
|
||||||
^bb4: // pred: ^bb3
|
^bb4: // pred: ^bb3
|
||||||
%c7 = arith.constant 7 : index
|
%c7 = arith.constant 7 : index
|
||||||
%c56 = arith.constant 56 : index
|
%c56 = arith.constant 56 : index
|
||||||
br ^bb5(%c7 : index)
|
cf.br ^bb5(%c7 : index)
|
||||||
^bb5(%2: index): // 2 preds: ^bb4, ^bb6
|
^bb5(%2: index): // 2 preds: ^bb4, ^bb6
|
||||||
%3 = arith.cmpi slt, %2, %c56 : index
|
%3 = arith.cmpi slt, %2, %c56 : index
|
||||||
cond_br %3, ^bb6, ^bb7
|
cf.cond_br %3, ^bb6, ^bb7
|
||||||
^bb6: // pred: ^bb5
|
^bb6: // pred: ^bb5
|
||||||
call @body2(%0, %2) : (index, index) -> ()
|
call @body2(%0, %2) : (index, index) -> ()
|
||||||
%c2 = arith.constant 2 : index
|
%c2 = arith.constant 2 : index
|
||||||
%4 = arith.addi %2, %c2 : index
|
%4 = arith.addi %2, %c2 : index
|
||||||
br ^bb5(%4 : index)
|
cf.br ^bb5(%4 : index)
|
||||||
^bb7: // pred: ^bb5
|
^bb7: // pred: ^bb5
|
||||||
call @mid(%0) : (index) -> ()
|
call @mid(%0) : (index) -> ()
|
||||||
br ^bb8
|
cf.br ^bb8
|
||||||
^bb8: // pred: ^bb7
|
^bb8: // pred: ^bb7
|
||||||
%c18 = arith.constant 18 : index
|
%c18 = arith.constant 18 : index
|
||||||
%c37 = arith.constant 37 : index
|
%c37 = arith.constant 37 : index
|
||||||
br ^bb9(%c18 : index)
|
cf.br ^bb9(%c18 : index)
|
||||||
^bb9(%5: index): // 2 preds: ^bb8, ^bb10
|
^bb9(%5: index): // 2 preds: ^bb8, ^bb10
|
||||||
%6 = arith.cmpi slt, %5, %c37 : index
|
%6 = arith.cmpi slt, %5, %c37 : index
|
||||||
cond_br %6, ^bb10, ^bb11
|
cf.cond_br %6, ^bb10, ^bb11
|
||||||
^bb10: // pred: ^bb9
|
^bb10: // pred: ^bb9
|
||||||
call @body3(%0, %5) : (index, index) -> ()
|
call @body3(%0, %5) : (index, index) -> ()
|
||||||
%c3 = arith.constant 3 : index
|
%c3 = arith.constant 3 : index
|
||||||
%7 = arith.addi %5, %c3 : index
|
%7 = arith.addi %5, %c3 : index
|
||||||
br ^bb9(%7 : index)
|
cf.br ^bb9(%7 : index)
|
||||||
^bb11: // pred: ^bb9
|
^bb11: // pred: ^bb9
|
||||||
call @post(%0) : (index) -> ()
|
call @post(%0) : (index) -> ()
|
||||||
%c1 = arith.constant 1 : index
|
%c1 = arith.constant 1 : index
|
||||||
%8 = arith.addi %0, %c1 : index
|
%8 = arith.addi %0, %c1 : index
|
||||||
br ^bb2(%8 : index)
|
cf.br ^bb2(%8 : index)
|
||||||
^bb12: // pred: ^bb2
|
^bb12: // pred: ^bb2
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -432,7 +432,7 @@ func @dfs_block_order(%arg0: i32) -> (i32) {
|
||||||
// CHECK-NEXT: %[[CST:.*]] = llvm.mlir.constant(42 : i32) : i32
|
// CHECK-NEXT: %[[CST:.*]] = llvm.mlir.constant(42 : i32) : i32
|
||||||
%0 = arith.constant 42 : i32
|
%0 = arith.constant 42 : i32
|
||||||
// CHECK-NEXT: llvm.br ^bb2
|
// CHECK-NEXT: llvm.br ^bb2
|
||||||
br ^bb2
|
cf.br ^bb2
|
||||||
|
|
||||||
// CHECK-NEXT: ^bb1:
|
// CHECK-NEXT: ^bb1:
|
||||||
// CHECK-NEXT: %[[ADD:.*]] = llvm.add %arg0, %[[CST]] : i32
|
// CHECK-NEXT: %[[ADD:.*]] = llvm.add %arg0, %[[CST]] : i32
|
||||||
|
@ -444,7 +444,7 @@ func @dfs_block_order(%arg0: i32) -> (i32) {
|
||||||
// CHECK-NEXT: ^bb2:
|
// CHECK-NEXT: ^bb2:
|
||||||
^bb2:
|
^bb2:
|
||||||
// CHECK-NEXT: llvm.br ^bb1
|
// CHECK-NEXT: llvm.br ^bb1
|
||||||
br ^bb1
|
cf.br ^bb1
|
||||||
}
|
}
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
@ -469,7 +469,7 @@ func @floorf(%arg0 : f32) {
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
// Lowers `assert` to a function call to `abort` if the assertion is violated.
|
// Lowers `cf.assert` to a function call to `abort` if the assertion is violated.
|
||||||
// CHECK: llvm.func @abort()
|
// CHECK: llvm.func @abort()
|
||||||
// CHECK-LABEL: @assert_test_function
|
// CHECK-LABEL: @assert_test_function
|
||||||
// CHECK-SAME: (%[[ARG:.*]]: i1)
|
// CHECK-SAME: (%[[ARG:.*]]: i1)
|
||||||
|
@ -480,7 +480,7 @@ func @assert_test_function(%arg : i1) {
|
||||||
// CHECK: ^[[FAILURE_BLOCK]]:
|
// CHECK: ^[[FAILURE_BLOCK]]:
|
||||||
// CHECK: llvm.call @abort() : () -> ()
|
// CHECK: llvm.call @abort() : () -> ()
|
||||||
// CHECK: llvm.unreachable
|
// CHECK: llvm.unreachable
|
||||||
assert %arg, "Computer says no"
|
cf.assert %arg, "Computer says no"
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -514,8 +514,8 @@ func @fmaf(%arg0: f32, %arg1: vector<4xf32>) {
|
||||||
|
|
||||||
// CHECK-LABEL: func @switchi8(
|
// CHECK-LABEL: func @switchi8(
|
||||||
func @switchi8(%arg0 : i8) -> i32 {
|
func @switchi8(%arg0 : i8) -> i32 {
|
||||||
switch %arg0 : i8, [
|
cf.switch %arg0 : i8, [
|
||||||
default: ^bb1,
|
default: ^bb1,
|
||||||
42: ^bb1,
|
42: ^bb1,
|
||||||
43: ^bb3
|
43: ^bb3
|
||||||
]
|
]
|
||||||
|
|
|
@ -900,45 +900,3 @@ func @tensor_extract_constant(%a : index, %b: index, %c: index) -> i32 {
|
||||||
// CHECK: spv.ReturnValue %[[VAL]]
|
// CHECK: spv.ReturnValue %[[VAL]]
|
||||||
return %extract : i32
|
return %extract : i32
|
||||||
}
|
}
|
||||||
|
|
||||||
// -----
|
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
|
||||||
// std.br, std.cond_br
|
|
||||||
//===----------------------------------------------------------------------===//
|
|
||||||
|
|
||||||
module attributes {
|
|
||||||
spv.target_env = #spv.target_env<#spv.vce<v1.0, [], []>, {}>
|
|
||||||
} {
|
|
||||||
|
|
||||||
// CHECK-LABEL: func @simple_loop
|
|
||||||
func @simple_loop(index, index, index) {
|
|
||||||
^bb0(%begin : index, %end : index, %step : index):
|
|
||||||
// CHECK-NEXT: spv.Branch ^bb1
|
|
||||||
br ^bb1
|
|
||||||
|
|
||||||
// CHECK-NEXT: ^bb1: // pred: ^bb0
|
|
||||||
// CHECK-NEXT: spv.Branch ^bb2({{.*}} : i32)
|
|
||||||
^bb1: // pred: ^bb0
|
|
||||||
br ^bb2(%begin : index)
|
|
||||||
|
|
||||||
// CHECK: ^bb2({{.*}}: i32): // 2 preds: ^bb1, ^bb3
|
|
||||||
// CHECK-NEXT: {{.*}} = spv.SLessThan {{.*}}, {{.*}} : i32
|
|
||||||
// CHECK-NEXT: spv.BranchConditional {{.*}}, ^bb3, ^bb4
|
|
||||||
^bb2(%0: index): // 2 preds: ^bb1, ^bb3
|
|
||||||
%1 = arith.cmpi slt, %0, %end : index
|
|
||||||
cond_br %1, ^bb3, ^bb4
|
|
||||||
|
|
||||||
// CHECK: ^bb3: // pred: ^bb2
|
|
||||||
// CHECK-NEXT: {{.*}} = spv.IAdd {{.*}}, {{.*}} : i32
|
|
||||||
// CHECK-NEXT: spv.Branch ^bb2({{.*}} : i32)
|
|
||||||
^bb3: // pred: ^bb2
|
|
||||||
%2 = arith.addi %0, %step : index
|
|
||||||
br ^bb2(%2 : index)
|
|
||||||
|
|
||||||
// CHECK: ^bb4: // pred: ^bb2
|
|
||||||
^bb4: // pred: ^bb2
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
|
@ -56,9 +56,9 @@ func @affine_load_invalid_dim(%M : memref<10xi32>) {
|
||||||
^bb0(%arg: index):
|
^bb0(%arg: index):
|
||||||
affine.load %M[%arg] : memref<10xi32>
|
affine.load %M[%arg] : memref<10xi32>
|
||||||
// expected-error@-1 {{index must be a dimension or symbol identifier}}
|
// expected-error@-1 {{index must be a dimension or symbol identifier}}
|
||||||
br ^bb1
|
cf.br ^bb1
|
||||||
^bb1:
|
^bb1:
|
||||||
br ^bb1
|
cf.br ^bb1
|
||||||
}) : () -> ()
|
}) : () -> ()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
|
@ -54,13 +54,13 @@ func @token_value_to_func() {
|
||||||
// CHECK-LABEL: @token_arg_cond_br_await_with_fallthough
|
// CHECK-LABEL: @token_arg_cond_br_await_with_fallthough
|
||||||
// CHECK: %[[TOKEN:.*]]: !async.token
|
// CHECK: %[[TOKEN:.*]]: !async.token
|
||||||
func @token_arg_cond_br_await_with_fallthough(%arg0: !async.token, %arg1: i1) {
|
func @token_arg_cond_br_await_with_fallthough(%arg0: !async.token, %arg1: i1) {
|
||||||
// CHECK: cond_br
|
// CHECK: cf.cond_br
|
||||||
// CHECK-SAME: ^[[BB1:.*]], ^[[BB2:.*]]
|
// CHECK-SAME: ^[[BB1:.*]], ^[[BB2:.*]]
|
||||||
cond_br %arg1, ^bb1, ^bb2
|
cf.cond_br %arg1, ^bb1, ^bb2
|
||||||
^bb1:
|
^bb1:
|
||||||
// CHECK: ^[[BB1]]:
|
// CHECK: ^[[BB1]]:
|
||||||
// CHECK: br ^[[BB2]]
|
// CHECK: cf.br ^[[BB2]]
|
||||||
br ^bb2
|
cf.br ^bb2
|
||||||
^bb2:
|
^bb2:
|
||||||
// CHECK: ^[[BB2]]:
|
// CHECK: ^[[BB2]]:
|
||||||
// CHECK: async.runtime.await %[[TOKEN]]
|
// CHECK: async.runtime.await %[[TOKEN]]
|
||||||
|
@ -88,10 +88,10 @@ func @token_coro_return() -> !async.token {
|
||||||
async.runtime.resume %hdl
|
async.runtime.resume %hdl
|
||||||
async.coro.suspend %saved, ^suspend, ^resume, ^cleanup
|
async.coro.suspend %saved, ^suspend, ^resume, ^cleanup
|
||||||
^resume:
|
^resume:
|
||||||
br ^cleanup
|
cf.br ^cleanup
|
||||||
^cleanup:
|
^cleanup:
|
||||||
async.coro.free %id, %hdl
|
async.coro.free %id, %hdl
|
||||||
br ^suspend
|
cf.br ^suspend
|
||||||
^suspend:
|
^suspend:
|
||||||
async.coro.end %hdl
|
async.coro.end %hdl
|
||||||
return %token : !async.token
|
return %token : !async.token
|
||||||
|
@ -109,10 +109,10 @@ func @token_coro_await_and_resume(%arg0: !async.token) -> !async.token {
|
||||||
// CHECK-NEXT: async.runtime.drop_ref %[[TOKEN]] {count = 1 : i64}
|
// CHECK-NEXT: async.runtime.drop_ref %[[TOKEN]] {count = 1 : i64}
|
||||||
async.coro.suspend %saved, ^suspend, ^resume, ^cleanup
|
async.coro.suspend %saved, ^suspend, ^resume, ^cleanup
|
||||||
^resume:
|
^resume:
|
||||||
br ^cleanup
|
cf.br ^cleanup
|
||||||
^cleanup:
|
^cleanup:
|
||||||
async.coro.free %id, %hdl
|
async.coro.free %id, %hdl
|
||||||
br ^suspend
|
cf.br ^suspend
|
||||||
^suspend:
|
^suspend:
|
||||||
async.coro.end %hdl
|
async.coro.end %hdl
|
||||||
return %token : !async.token
|
return %token : !async.token
|
||||||
|
@ -137,10 +137,10 @@ func @value_coro_await_and_resume(%arg0: !async.value<f32>) -> !async.token {
|
||||||
%0 = async.runtime.load %arg0 : !async.value<f32>
|
%0 = async.runtime.load %arg0 : !async.value<f32>
|
||||||
// CHECK: arith.addf %[[LOADED]], %[[LOADED]]
|
// CHECK: arith.addf %[[LOADED]], %[[LOADED]]
|
||||||
%1 = arith.addf %0, %0 : f32
|
%1 = arith.addf %0, %0 : f32
|
||||||
br ^cleanup
|
cf.br ^cleanup
|
||||||
^cleanup:
|
^cleanup:
|
||||||
async.coro.free %id, %hdl
|
async.coro.free %id, %hdl
|
||||||
br ^suspend
|
cf.br ^suspend
|
||||||
^suspend:
|
^suspend:
|
||||||
async.coro.end %hdl
|
async.coro.end %hdl
|
||||||
return %token : !async.token
|
return %token : !async.token
|
||||||
|
@ -167,12 +167,12 @@ func private @outlined_async_execute(%arg0: !async.token) -> !async.token {
|
||||||
// CHECK: ^[[RESUME_1:.*]]:
|
// CHECK: ^[[RESUME_1:.*]]:
|
||||||
// CHECK: async.runtime.set_available
|
// CHECK: async.runtime.set_available
|
||||||
async.runtime.set_available %0 : !async.token
|
async.runtime.set_available %0 : !async.token
|
||||||
br ^cleanup
|
cf.br ^cleanup
|
||||||
^cleanup:
|
^cleanup:
|
||||||
// CHECK: ^[[CLEANUP:.*]]:
|
// CHECK: ^[[CLEANUP:.*]]:
|
||||||
// CHECK: async.coro.free
|
// CHECK: async.coro.free
|
||||||
async.coro.free %1, %2
|
async.coro.free %1, %2
|
||||||
br ^suspend
|
cf.br ^suspend
|
||||||
^suspend:
|
^suspend:
|
||||||
// CHECK: ^[[SUSPEND:.*]]:
|
// CHECK: ^[[SUSPEND:.*]]:
|
||||||
// CHECK: async.coro.end
|
// CHECK: async.coro.end
|
||||||
|
@ -198,7 +198,7 @@ func @token_await_inside_nested_region(%arg0: i1) {
|
||||||
|
|
||||||
// CHECK-LABEL: @token_defined_in_the_loop
|
// CHECK-LABEL: @token_defined_in_the_loop
|
||||||
func @token_defined_in_the_loop() {
|
func @token_defined_in_the_loop() {
|
||||||
br ^bb1
|
cf.br ^bb1
|
||||||
^bb1:
|
^bb1:
|
||||||
// CHECK: ^[[BB1:.*]]:
|
// CHECK: ^[[BB1:.*]]:
|
||||||
// CHECK: %[[TOKEN:.*]] = call @token()
|
// CHECK: %[[TOKEN:.*]] = call @token()
|
||||||
|
@ -207,7 +207,7 @@ func @token_defined_in_the_loop() {
|
||||||
// CHECK: async.runtime.drop_ref %[[TOKEN]] {count = 1 : i64}
|
// CHECK: async.runtime.drop_ref %[[TOKEN]] {count = 1 : i64}
|
||||||
async.runtime.await %token : !async.token
|
async.runtime.await %token : !async.token
|
||||||
%0 = call @cond(): () -> (i1)
|
%0 = call @cond(): () -> (i1)
|
||||||
cond_br %0, ^bb1, ^bb2
|
cf.cond_br %0, ^bb1, ^bb2
|
||||||
^bb2:
|
^bb2:
|
||||||
// CHECK: ^[[BB2:.*]]:
|
// CHECK: ^[[BB2:.*]]:
|
||||||
// CHECK: return
|
// CHECK: return
|
||||||
|
@ -218,18 +218,18 @@ func @token_defined_in_the_loop() {
|
||||||
func @divergent_liveness_one_token(%arg0 : i1) {
|
func @divergent_liveness_one_token(%arg0 : i1) {
|
||||||
// CHECK: %[[TOKEN:.*]] = call @token()
|
// CHECK: %[[TOKEN:.*]] = call @token()
|
||||||
%token = call @token() : () -> !async.token
|
%token = call @token() : () -> !async.token
|
||||||
// CHECK: cond_br %arg0, ^[[LIVE_IN:.*]], ^[[REF_COUNTING:.*]]
|
// CHECK: cf.cond_br %arg0, ^[[LIVE_IN:.*]], ^[[REF_COUNTING:.*]]
|
||||||
cond_br %arg0, ^bb1, ^bb2
|
cf.cond_br %arg0, ^bb1, ^bb2
|
||||||
^bb1:
|
^bb1:
|
||||||
// CHECK: ^[[LIVE_IN]]:
|
// CHECK: ^[[LIVE_IN]]:
|
||||||
// CHECK: async.runtime.await %[[TOKEN]]
|
// CHECK: async.runtime.await %[[TOKEN]]
|
||||||
// CHECK: async.runtime.drop_ref %[[TOKEN]] {count = 1 : i64}
|
// CHECK: async.runtime.drop_ref %[[TOKEN]] {count = 1 : i64}
|
||||||
// CHECK: br ^[[RETURN:.*]]
|
// CHECK: cf.br ^[[RETURN:.*]]
|
||||||
async.runtime.await %token : !async.token
|
async.runtime.await %token : !async.token
|
||||||
br ^bb2
|
cf.br ^bb2
|
||||||
// CHECK: ^[[REF_COUNTING:.*]]:
|
// CHECK: ^[[REF_COUNTING:.*]]:
|
||||||
// CHECK: async.runtime.drop_ref %[[TOKEN]] {count = 1 : i64}
|
// CHECK: async.runtime.drop_ref %[[TOKEN]] {count = 1 : i64}
|
||||||
// CHECK: br ^[[RETURN:.*]]
|
// CHECK: cf.br ^[[RETURN:.*]]
|
||||||
^bb2:
|
^bb2:
|
||||||
// CHECK: ^[[RETURN]]:
|
// CHECK: ^[[RETURN]]:
|
||||||
// CHECK: return
|
// CHECK: return
|
||||||
|
@ -240,20 +240,20 @@ func @divergent_liveness_one_token(%arg0 : i1) {
|
||||||
func @divergent_liveness_unique_predecessor(%arg0 : i1) {
|
func @divergent_liveness_unique_predecessor(%arg0 : i1) {
|
||||||
// CHECK: %[[TOKEN:.*]] = call @token()
|
// CHECK: %[[TOKEN:.*]] = call @token()
|
||||||
%token = call @token() : () -> !async.token
|
%token = call @token() : () -> !async.token
|
||||||
// CHECK: cond_br %arg0, ^[[LIVE_IN:.*]], ^[[NO_LIVE_IN:.*]]
|
// CHECK: cf.cond_br %arg0, ^[[LIVE_IN:.*]], ^[[NO_LIVE_IN:.*]]
|
||||||
cond_br %arg0, ^bb2, ^bb1
|
cf.cond_br %arg0, ^bb2, ^bb1
|
||||||
^bb1:
|
^bb1:
|
||||||
// CHECK: ^[[NO_LIVE_IN]]:
|
// CHECK: ^[[NO_LIVE_IN]]:
|
||||||
// CHECK: async.runtime.drop_ref %[[TOKEN]] {count = 1 : i64}
|
// CHECK: async.runtime.drop_ref %[[TOKEN]] {count = 1 : i64}
|
||||||
// CHECK: br ^[[RETURN:.*]]
|
// CHECK: cf.br ^[[RETURN:.*]]
|
||||||
br ^bb3
|
cf.br ^bb3
|
||||||
^bb2:
|
^bb2:
|
||||||
// CHECK: ^[[LIVE_IN]]:
|
// CHECK: ^[[LIVE_IN]]:
|
||||||
// CHECK: async.runtime.await %[[TOKEN]]
|
// CHECK: async.runtime.await %[[TOKEN]]
|
||||||
// CHECK: async.runtime.drop_ref %[[TOKEN]] {count = 1 : i64}
|
// CHECK: async.runtime.drop_ref %[[TOKEN]] {count = 1 : i64}
|
||||||
// CHECK: br ^[[RETURN]]
|
// CHECK: cf.br ^[[RETURN]]
|
||||||
async.runtime.await %token : !async.token
|
async.runtime.await %token : !async.token
|
||||||
br ^bb3
|
cf.br ^bb3
|
||||||
^bb3:
|
^bb3:
|
||||||
// CHECK: ^[[RETURN]]:
|
// CHECK: ^[[RETURN]]:
|
||||||
// CHECK: return
|
// CHECK: return
|
||||||
|
@ -266,24 +266,24 @@ func @divergent_liveness_two_tokens(%arg0 : i1) {
|
||||||
// CHECK: %[[TOKEN1:.*]] = call @token()
|
// CHECK: %[[TOKEN1:.*]] = call @token()
|
||||||
%token0 = call @token() : () -> !async.token
|
%token0 = call @token() : () -> !async.token
|
||||||
%token1 = call @token() : () -> !async.token
|
%token1 = call @token() : () -> !async.token
|
||||||
// CHECK: cond_br %arg0, ^[[AWAIT0:.*]], ^[[AWAIT1:.*]]
|
// CHECK: cf.cond_br %arg0, ^[[AWAIT0:.*]], ^[[AWAIT1:.*]]
|
||||||
cond_br %arg0, ^await0, ^await1
|
cf.cond_br %arg0, ^await0, ^await1
|
||||||
^await0:
|
^await0:
|
||||||
// CHECK: ^[[AWAIT0]]:
|
// CHECK: ^[[AWAIT0]]:
|
||||||
// CHECK: async.runtime.drop_ref %[[TOKEN1]] {count = 1 : i64}
|
// CHECK: async.runtime.drop_ref %[[TOKEN1]] {count = 1 : i64}
|
||||||
// CHECK: async.runtime.await %[[TOKEN0]]
|
// CHECK: async.runtime.await %[[TOKEN0]]
|
||||||
// CHECK: async.runtime.drop_ref %[[TOKEN0]] {count = 1 : i64}
|
// CHECK: async.runtime.drop_ref %[[TOKEN0]] {count = 1 : i64}
|
||||||
// CHECK: br ^[[RETURN:.*]]
|
// CHECK: cf.br ^[[RETURN:.*]]
|
||||||
async.runtime.await %token0 : !async.token
|
async.runtime.await %token0 : !async.token
|
||||||
br ^ret
|
cf.br ^ret
|
||||||
^await1:
|
^await1:
|
||||||
// CHECK: ^[[AWAIT1]]:
|
// CHECK: ^[[AWAIT1]]:
|
||||||
// CHECK: async.runtime.drop_ref %[[TOKEN0]] {count = 1 : i64}
|
// CHECK: async.runtime.drop_ref %[[TOKEN0]] {count = 1 : i64}
|
||||||
// CHECK: async.runtime.await %[[TOKEN1]]
|
// CHECK: async.runtime.await %[[TOKEN1]]
|
||||||
// CHECK: async.runtime.drop_ref %[[TOKEN1]] {count = 1 : i64}
|
// CHECK: async.runtime.drop_ref %[[TOKEN1]] {count = 1 : i64}
|
||||||
// CHECK: br ^[[RETURN]]
|
// CHECK: cf.br ^[[RETURN]]
|
||||||
async.runtime.await %token1 : !async.token
|
async.runtime.await %token1 : !async.token
|
||||||
br ^ret
|
cf.br ^ret
|
||||||
^ret:
|
^ret:
|
||||||
// CHECK: ^[[RETURN]]:
|
// CHECK: ^[[RETURN]]:
|
||||||
// CHECK: return
|
// CHECK: return
|
||||||
|
|
|
@ -10,7 +10,7 @@ func @simple_callee(%arg0: f32) -> (f32 {builtin.foo = "bar"}) {
|
||||||
// CHECK: %[[RETURNED_STORAGE:.*]] = async.runtime.create : !async.value<f32>
|
// CHECK: %[[RETURNED_STORAGE:.*]] = async.runtime.create : !async.value<f32>
|
||||||
// CHECK: %[[ID:.*]] = async.coro.id
|
// CHECK: %[[ID:.*]] = async.coro.id
|
||||||
// CHECK: %[[HDL:.*]] = async.coro.begin %[[ID]]
|
// CHECK: %[[HDL:.*]] = async.coro.begin %[[ID]]
|
||||||
// CHECK: br ^[[ORIGINAL_ENTRY:.*]]
|
// CHECK: cf.br ^[[ORIGINAL_ENTRY:.*]]
|
||||||
// CHECK ^[[ORIGINAL_ENTRY]]:
|
// CHECK ^[[ORIGINAL_ENTRY]]:
|
||||||
// CHECK: %[[VAL:.*]] = arith.addf %[[ARG]], %[[ARG]] : f32
|
// CHECK: %[[VAL:.*]] = arith.addf %[[ARG]], %[[ARG]] : f32
|
||||||
%0 = arith.addf %arg0, %arg0 : f32
|
%0 = arith.addf %arg0, %arg0 : f32
|
||||||
|
@ -29,7 +29,7 @@ func @simple_callee(%arg0: f32) -> (f32 {builtin.foo = "bar"}) {
|
||||||
|
|
||||||
// CHECK: ^[[RESUME]]:
|
// CHECK: ^[[RESUME]]:
|
||||||
// CHECK: %[[IS_ERROR:.*]] = async.runtime.is_error %[[VAL_STORAGE]] : !async.value<f32>
|
// CHECK: %[[IS_ERROR:.*]] = async.runtime.is_error %[[VAL_STORAGE]] : !async.value<f32>
|
||||||
// CHECK: cond_br %[[IS_ERROR]], ^[[BRANCH_ERROR:.*]], ^[[BRANCH_OK:.*]]
|
// CHECK: cf.cond_br %[[IS_ERROR]], ^[[BRANCH_ERROR:.*]], ^[[BRANCH_OK:.*]]
|
||||||
|
|
||||||
// CHECK: ^[[BRANCH_OK]]:
|
// CHECK: ^[[BRANCH_OK]]:
|
||||||
// CHECK: %[[LOADED:.*]] = async.runtime.load %[[VAL_STORAGE]] : <f32>
|
// CHECK: %[[LOADED:.*]] = async.runtime.load %[[VAL_STORAGE]] : <f32>
|
||||||
|
@ -37,19 +37,19 @@ func @simple_callee(%arg0: f32) -> (f32 {builtin.foo = "bar"}) {
|
||||||
// CHECK: async.runtime.store %[[RETURNED]], %[[RETURNED_STORAGE]] : <f32>
|
// CHECK: async.runtime.store %[[RETURNED]], %[[RETURNED_STORAGE]] : <f32>
|
||||||
// CHECK: async.runtime.set_available %[[RETURNED_STORAGE]]
|
// CHECK: async.runtime.set_available %[[RETURNED_STORAGE]]
|
||||||
// CHECK: async.runtime.set_available %[[TOKEN]]
|
// CHECK: async.runtime.set_available %[[TOKEN]]
|
||||||
// CHECK: br ^[[CLEANUP]]
|
// CHECK: cf.br ^[[CLEANUP]]
|
||||||
%3 = arith.mulf %arg0, %2 : f32
|
%3 = arith.mulf %arg0, %2 : f32
|
||||||
return %3: f32
|
return %3: f32
|
||||||
|
|
||||||
// CHECK: ^[[BRANCH_ERROR]]:
|
// CHECK: ^[[BRANCH_ERROR]]:
|
||||||
// CHECK: async.runtime.set_error %[[TOKEN]]
|
// CHECK: async.runtime.set_error %[[TOKEN]]
|
||||||
// CHECK: async.runtime.set_error %[[RETURNED_STORAGE]]
|
// CHECK: async.runtime.set_error %[[RETURNED_STORAGE]]
|
||||||
// CHECK: br ^[[CLEANUP]]
|
// CHECK: cf.br ^[[CLEANUP]]
|
||||||
|
|
||||||
|
|
||||||
// CHECK: ^[[CLEANUP]]:
|
// CHECK: ^[[CLEANUP]]:
|
||||||
// CHECK: async.coro.free %[[ID]], %[[HDL]]
|
// CHECK: async.coro.free %[[ID]], %[[HDL]]
|
||||||
// CHECK: br ^[[SUSPEND]]
|
// CHECK: cf.br ^[[SUSPEND]]
|
||||||
|
|
||||||
// CHECK: ^[[SUSPEND]]:
|
// CHECK: ^[[SUSPEND]]:
|
||||||
// CHECK: async.coro.end %[[HDL]]
|
// CHECK: async.coro.end %[[HDL]]
|
||||||
|
@ -63,7 +63,7 @@ func @simple_caller() -> f32 {
|
||||||
// CHECK: %[[RETURNED_STORAGE:.*]] = async.runtime.create : !async.value<f32>
|
// CHECK: %[[RETURNED_STORAGE:.*]] = async.runtime.create : !async.value<f32>
|
||||||
// CHECK: %[[ID:.*]] = async.coro.id
|
// CHECK: %[[ID:.*]] = async.coro.id
|
||||||
// CHECK: %[[HDL:.*]] = async.coro.begin %[[ID]]
|
// CHECK: %[[HDL:.*]] = async.coro.begin %[[ID]]
|
||||||
// CHECK: br ^[[ORIGINAL_ENTRY:.*]]
|
// CHECK: cf.br ^[[ORIGINAL_ENTRY:.*]]
|
||||||
// CHECK ^[[ORIGINAL_ENTRY]]:
|
// CHECK ^[[ORIGINAL_ENTRY]]:
|
||||||
|
|
||||||
// CHECK: %[[CONSTANT:.*]] = arith.constant
|
// CHECK: %[[CONSTANT:.*]] = arith.constant
|
||||||
|
@ -77,28 +77,28 @@ func @simple_caller() -> f32 {
|
||||||
|
|
||||||
// CHECK: ^[[RESUME]]:
|
// CHECK: ^[[RESUME]]:
|
||||||
// CHECK: %[[IS_TOKEN_ERROR:.*]] = async.runtime.is_error %[[RETURNED_TO_CALLER]]#0 : !async.token
|
// CHECK: %[[IS_TOKEN_ERROR:.*]] = async.runtime.is_error %[[RETURNED_TO_CALLER]]#0 : !async.token
|
||||||
// CHECK: cond_br %[[IS_TOKEN_ERROR]], ^[[BRANCH_ERROR:.*]], ^[[BRANCH_TOKEN_OK:.*]]
|
// CHECK: cf.cond_br %[[IS_TOKEN_ERROR]], ^[[BRANCH_ERROR:.*]], ^[[BRANCH_TOKEN_OK:.*]]
|
||||||
|
|
||||||
// CHECK: ^[[BRANCH_TOKEN_OK]]:
|
// CHECK: ^[[BRANCH_TOKEN_OK]]:
|
||||||
// CHECK: %[[IS_VALUE_ERROR:.*]] = async.runtime.is_error %[[RETURNED_TO_CALLER]]#1 : !async.value<f32>
|
// CHECK: %[[IS_VALUE_ERROR:.*]] = async.runtime.is_error %[[RETURNED_TO_CALLER]]#1 : !async.value<f32>
|
||||||
// CHECK: cond_br %[[IS_VALUE_ERROR]], ^[[BRANCH_ERROR:.*]], ^[[BRANCH_VALUE_OK:.*]]
|
// CHECK: cf.cond_br %[[IS_VALUE_ERROR]], ^[[BRANCH_ERROR:.*]], ^[[BRANCH_VALUE_OK:.*]]
|
||||||
|
|
||||||
// CHECK: ^[[BRANCH_VALUE_OK]]:
|
// CHECK: ^[[BRANCH_VALUE_OK]]:
|
||||||
// CHECK: %[[LOADED:.*]] = async.runtime.load %[[RETURNED_TO_CALLER]]#1 : <f32>
|
// CHECK: %[[LOADED:.*]] = async.runtime.load %[[RETURNED_TO_CALLER]]#1 : <f32>
|
||||||
// CHECK: async.runtime.store %[[LOADED]], %[[RETURNED_STORAGE]] : <f32>
|
// CHECK: async.runtime.store %[[LOADED]], %[[RETURNED_STORAGE]] : <f32>
|
||||||
// CHECK: async.runtime.set_available %[[RETURNED_STORAGE]]
|
// CHECK: async.runtime.set_available %[[RETURNED_STORAGE]]
|
||||||
// CHECK: async.runtime.set_available %[[TOKEN]]
|
// CHECK: async.runtime.set_available %[[TOKEN]]
|
||||||
// CHECK: br ^[[CLEANUP]]
|
// CHECK: cf.br ^[[CLEANUP]]
|
||||||
return %r: f32
|
return %r: f32
|
||||||
// CHECK: ^[[BRANCH_ERROR]]:
|
// CHECK: ^[[BRANCH_ERROR]]:
|
||||||
// CHECK: async.runtime.set_error %[[TOKEN]]
|
// CHECK: async.runtime.set_error %[[TOKEN]]
|
||||||
// CHECK: async.runtime.set_error %[[RETURNED_STORAGE]]
|
// CHECK: async.runtime.set_error %[[RETURNED_STORAGE]]
|
||||||
// CHECK: br ^[[CLEANUP]]
|
// CHECK: cf.br ^[[CLEANUP]]
|
||||||
|
|
||||||
|
|
||||||
// CHECK: ^[[CLEANUP]]:
|
// CHECK: ^[[CLEANUP]]:
|
||||||
// CHECK: async.coro.free %[[ID]], %[[HDL]]
|
// CHECK: async.coro.free %[[ID]], %[[HDL]]
|
||||||
// CHECK: br ^[[SUSPEND]]
|
// CHECK: cf.br ^[[SUSPEND]]
|
||||||
|
|
||||||
// CHECK: ^[[SUSPEND]]:
|
// CHECK: ^[[SUSPEND]]:
|
||||||
// CHECK: async.coro.end %[[HDL]]
|
// CHECK: async.coro.end %[[HDL]]
|
||||||
|
@ -112,7 +112,7 @@ func @double_caller() -> f32 {
|
||||||
// CHECK: %[[RETURNED_STORAGE:.*]] = async.runtime.create : !async.value<f32>
|
// CHECK: %[[RETURNED_STORAGE:.*]] = async.runtime.create : !async.value<f32>
|
||||||
// CHECK: %[[ID:.*]] = async.coro.id
|
// CHECK: %[[ID:.*]] = async.coro.id
|
||||||
// CHECK: %[[HDL:.*]] = async.coro.begin %[[ID]]
|
// CHECK: %[[HDL:.*]] = async.coro.begin %[[ID]]
|
||||||
// CHECK: br ^[[ORIGINAL_ENTRY:.*]]
|
// CHECK: cf.br ^[[ORIGINAL_ENTRY:.*]]
|
||||||
// CHECK ^[[ORIGINAL_ENTRY]]:
|
// CHECK ^[[ORIGINAL_ENTRY]]:
|
||||||
|
|
||||||
// CHECK: %[[CONSTANT:.*]] = arith.constant
|
// CHECK: %[[CONSTANT:.*]] = arith.constant
|
||||||
|
@ -126,11 +126,11 @@ func @double_caller() -> f32 {
|
||||||
|
|
||||||
// CHECK: ^[[RESUME_1]]:
|
// CHECK: ^[[RESUME_1]]:
|
||||||
// CHECK: %[[IS_TOKEN_ERROR_1:.*]] = async.runtime.is_error %[[RETURNED_TO_CALLER_1]]#0 : !async.token
|
// CHECK: %[[IS_TOKEN_ERROR_1:.*]] = async.runtime.is_error %[[RETURNED_TO_CALLER_1]]#0 : !async.token
|
||||||
// CHECK: cond_br %[[IS_TOKEN_ERROR_1]], ^[[BRANCH_ERROR:.*]], ^[[BRANCH_TOKEN_OK_1:.*]]
|
// CHECK: cf.cond_br %[[IS_TOKEN_ERROR_1]], ^[[BRANCH_ERROR:.*]], ^[[BRANCH_TOKEN_OK_1:.*]]
|
||||||
|
|
||||||
// CHECK: ^[[BRANCH_TOKEN_OK_1]]:
|
// CHECK: ^[[BRANCH_TOKEN_OK_1]]:
|
||||||
// CHECK: %[[IS_VALUE_ERROR_1:.*]] = async.runtime.is_error %[[RETURNED_TO_CALLER_1]]#1 : !async.value<f32>
|
// CHECK: %[[IS_VALUE_ERROR_1:.*]] = async.runtime.is_error %[[RETURNED_TO_CALLER_1]]#1 : !async.value<f32>
|
||||||
// CHECK: cond_br %[[IS_VALUE_ERROR_1]], ^[[BRANCH_ERROR:.*]], ^[[BRANCH_VALUE_OK_1:.*]]
|
// CHECK: cf.cond_br %[[IS_VALUE_ERROR_1]], ^[[BRANCH_ERROR:.*]], ^[[BRANCH_VALUE_OK_1:.*]]
|
||||||
|
|
||||||
// CHECK: ^[[BRANCH_VALUE_OK_1]]:
|
// CHECK: ^[[BRANCH_VALUE_OK_1]]:
|
||||||
// CHECK: %[[LOADED_1:.*]] = async.runtime.load %[[RETURNED_TO_CALLER_1]]#1 : <f32>
|
// CHECK: %[[LOADED_1:.*]] = async.runtime.load %[[RETURNED_TO_CALLER_1]]#1 : <f32>
|
||||||
|
@ -143,27 +143,27 @@ func @double_caller() -> f32 {
|
||||||
|
|
||||||
// CHECK: ^[[RESUME_2]]:
|
// CHECK: ^[[RESUME_2]]:
|
||||||
// CHECK: %[[IS_TOKEN_ERROR_2:.*]] = async.runtime.is_error %[[RETURNED_TO_CALLER_2]]#0 : !async.token
|
// CHECK: %[[IS_TOKEN_ERROR_2:.*]] = async.runtime.is_error %[[RETURNED_TO_CALLER_2]]#0 : !async.token
|
||||||
// CHECK: cond_br %[[IS_TOKEN_ERROR_2]], ^[[BRANCH_ERROR:.*]], ^[[BRANCH_TOKEN_OK_2:.*]]
|
// CHECK: cf.cond_br %[[IS_TOKEN_ERROR_2]], ^[[BRANCH_ERROR:.*]], ^[[BRANCH_TOKEN_OK_2:.*]]
|
||||||
|
|
||||||
// CHECK: ^[[BRANCH_TOKEN_OK_2]]:
|
// CHECK: ^[[BRANCH_TOKEN_OK_2]]:
|
||||||
// CHECK: %[[IS_VALUE_ERROR_2:.*]] = async.runtime.is_error %[[RETURNED_TO_CALLER_2]]#1 : !async.value<f32>
|
// CHECK: %[[IS_VALUE_ERROR_2:.*]] = async.runtime.is_error %[[RETURNED_TO_CALLER_2]]#1 : !async.value<f32>
|
||||||
// CHECK: cond_br %[[IS_VALUE_ERROR_2]], ^[[BRANCH_ERROR:.*]], ^[[BRANCH_VALUE_OK_2:.*]]
|
// CHECK: cf.cond_br %[[IS_VALUE_ERROR_2]], ^[[BRANCH_ERROR:.*]], ^[[BRANCH_VALUE_OK_2:.*]]
|
||||||
|
|
||||||
// CHECK: ^[[BRANCH_VALUE_OK_2]]:
|
// CHECK: ^[[BRANCH_VALUE_OK_2]]:
|
||||||
// CHECK: %[[LOADED_2:.*]] = async.runtime.load %[[RETURNED_TO_CALLER_2]]#1 : <f32>
|
// CHECK: %[[LOADED_2:.*]] = async.runtime.load %[[RETURNED_TO_CALLER_2]]#1 : <f32>
|
||||||
// CHECK: async.runtime.store %[[LOADED_2]], %[[RETURNED_STORAGE]] : <f32>
|
// CHECK: async.runtime.store %[[LOADED_2]], %[[RETURNED_STORAGE]] : <f32>
|
||||||
// CHECK: async.runtime.set_available %[[RETURNED_STORAGE]]
|
// CHECK: async.runtime.set_available %[[RETURNED_STORAGE]]
|
||||||
// CHECK: async.runtime.set_available %[[TOKEN]]
|
// CHECK: async.runtime.set_available %[[TOKEN]]
|
||||||
// CHECK: br ^[[CLEANUP]]
|
// CHECK: cf.br ^[[CLEANUP]]
|
||||||
return %s: f32
|
return %s: f32
|
||||||
// CHECK: ^[[BRANCH_ERROR]]:
|
// CHECK: ^[[BRANCH_ERROR]]:
|
||||||
// CHECK: async.runtime.set_error %[[TOKEN]]
|
// CHECK: async.runtime.set_error %[[TOKEN]]
|
||||||
// CHECK: async.runtime.set_error %[[RETURNED_STORAGE]]
|
// CHECK: async.runtime.set_error %[[RETURNED_STORAGE]]
|
||||||
// CHECK: br ^[[CLEANUP]]
|
// CHECK: cf.br ^[[CLEANUP]]
|
||||||
|
|
||||||
// CHECK: ^[[CLEANUP]]:
|
// CHECK: ^[[CLEANUP]]:
|
||||||
// CHECK: async.coro.free %[[ID]], %[[HDL]]
|
// CHECK: async.coro.free %[[ID]], %[[HDL]]
|
||||||
// CHECK: br ^[[SUSPEND]]
|
// CHECK: cf.br ^[[SUSPEND]]
|
||||||
|
|
||||||
// CHECK: ^[[SUSPEND]]:
|
// CHECK: ^[[SUSPEND]]:
|
||||||
// CHECK: async.coro.end %[[HDL]]
|
// CHECK: async.coro.end %[[HDL]]
|
||||||
|
@ -184,7 +184,7 @@ func @recursive(%arg: !async.token) {
|
||||||
async.await %arg : !async.token
|
async.await %arg : !async.token
|
||||||
// CHECK: ^[[RESUME_1]]:
|
// CHECK: ^[[RESUME_1]]:
|
||||||
// CHECK: %[[IS_ERROR:.*]] = async.runtime.is_error %[[ARG]] : !async.token
|
// CHECK: %[[IS_ERROR:.*]] = async.runtime.is_error %[[ARG]] : !async.token
|
||||||
// CHECK: cond_br %[[IS_ERROR]], ^[[BRANCH_ERROR:.*]], ^[[BRANCH_OK:.*]]
|
// CHECK: cf.cond_br %[[IS_ERROR]], ^[[BRANCH_ERROR:.*]], ^[[BRANCH_OK:.*]]
|
||||||
|
|
||||||
// CHECK: ^[[BRANCH_OK]]:
|
// CHECK: ^[[BRANCH_OK]]:
|
||||||
// CHECK: %[[GIVEN:.*]] = async.runtime.create : !async.token
|
// CHECK: %[[GIVEN:.*]] = async.runtime.create : !async.token
|
||||||
|
@ -200,16 +200,16 @@ call @recursive(%r): (!async.token) -> ()
|
||||||
|
|
||||||
// CHECK: ^[[RESUME_2]]:
|
// CHECK: ^[[RESUME_2]]:
|
||||||
// CHECK: async.runtime.set_available %[[TOKEN]]
|
// CHECK: async.runtime.set_available %[[TOKEN]]
|
||||||
// CHECK: br ^[[CLEANUP]]
|
// CHECK: cf.br ^[[CLEANUP]]
|
||||||
|
|
||||||
// CHECK: ^[[BRANCH_ERROR]]:
|
// CHECK: ^[[BRANCH_ERROR]]:
|
||||||
// CHECK: async.runtime.set_error %[[TOKEN]]
|
// CHECK: async.runtime.set_error %[[TOKEN]]
|
||||||
// CHECK: br ^[[CLEANUP]]
|
// CHECK: cf.br ^[[CLEANUP]]
|
||||||
return
|
return
|
||||||
|
|
||||||
// CHECK: ^[[CLEANUP]]:
|
// CHECK: ^[[CLEANUP]]:
|
||||||
// CHECK: async.coro.free %[[ID]], %[[HDL]]
|
// CHECK: async.coro.free %[[ID]], %[[HDL]]
|
||||||
// CHECK: br ^[[SUSPEND]]
|
// CHECK: cf.br ^[[SUSPEND]]
|
||||||
|
|
||||||
// CHECK: ^[[SUSPEND]]:
|
// CHECK: ^[[SUSPEND]]:
|
||||||
// CHECK: async.coro.end %[[HDL]]
|
// CHECK: async.coro.end %[[HDL]]
|
||||||
|
@ -230,7 +230,7 @@ func @corecursive1(%arg: !async.token) {
|
||||||
async.await %arg : !async.token
|
async.await %arg : !async.token
|
||||||
// CHECK: ^[[RESUME_1]]:
|
// CHECK: ^[[RESUME_1]]:
|
||||||
// CHECK: %[[IS_ERROR:.*]] = async.runtime.is_error %[[ARG]] : !async.token
|
// CHECK: %[[IS_ERROR:.*]] = async.runtime.is_error %[[ARG]] : !async.token
|
||||||
// CHECK: cond_br %[[IS_ERROR]], ^[[BRANCH_ERROR:.*]], ^[[BRANCH_OK:.*]]
|
// CHECK: cf.cond_br %[[IS_ERROR]], ^[[BRANCH_ERROR:.*]], ^[[BRANCH_OK:.*]]
|
||||||
|
|
||||||
// CHECK: ^[[BRANCH_OK]]:
|
// CHECK: ^[[BRANCH_OK]]:
|
||||||
// CHECK: %[[GIVEN:.*]] = async.runtime.create : !async.token
|
// CHECK: %[[GIVEN:.*]] = async.runtime.create : !async.token
|
||||||
|
@ -246,16 +246,16 @@ call @corecursive2(%r): (!async.token) -> ()
|
||||||
|
|
||||||
// CHECK: ^[[RESUME_2]]:
|
// CHECK: ^[[RESUME_2]]:
|
||||||
// CHECK: async.runtime.set_available %[[TOKEN]]
|
// CHECK: async.runtime.set_available %[[TOKEN]]
|
||||||
// CHECK: br ^[[CLEANUP]]
|
// CHECK: cf.br ^[[CLEANUP]]
|
||||||
|
|
||||||
// CHECK: ^[[BRANCH_ERROR]]:
|
// CHECK: ^[[BRANCH_ERROR]]:
|
||||||
// CHECK: async.runtime.set_error %[[TOKEN]]
|
// CHECK: async.runtime.set_error %[[TOKEN]]
|
||||||
// CHECK: br ^[[CLEANUP]]
|
// CHECK: cf.br ^[[CLEANUP]]
|
||||||
return
|
return
|
||||||
|
|
||||||
// CHECK: ^[[CLEANUP]]:
|
// CHECK: ^[[CLEANUP]]:
|
||||||
// CHECK: async.coro.free %[[ID]], %[[HDL]]
|
// CHECK: async.coro.free %[[ID]], %[[HDL]]
|
||||||
// CHECK: br ^[[SUSPEND]]
|
// CHECK: cf.br ^[[SUSPEND]]
|
||||||
|
|
||||||
// CHECK: ^[[SUSPEND]]:
|
// CHECK: ^[[SUSPEND]]:
|
||||||
// CHECK: async.coro.end %[[HDL]]
|
// CHECK: async.coro.end %[[HDL]]
|
||||||
|
@ -276,7 +276,7 @@ func @corecursive2(%arg: !async.token) {
|
||||||
async.await %arg : !async.token
|
async.await %arg : !async.token
|
||||||
// CHECK: ^[[RESUME_1]]:
|
// CHECK: ^[[RESUME_1]]:
|
||||||
// CHECK: %[[IS_ERROR:.*]] = async.runtime.is_error %[[ARG]] : !async.token
|
// CHECK: %[[IS_ERROR:.*]] = async.runtime.is_error %[[ARG]] : !async.token
|
||||||
// CHECK: cond_br %[[IS_ERROR]], ^[[BRANCH_ERROR:.*]], ^[[BRANCH_OK:.*]]
|
// CHECK: cf.cond_br %[[IS_ERROR]], ^[[BRANCH_ERROR:.*]], ^[[BRANCH_OK:.*]]
|
||||||
|
|
||||||
// CHECK: ^[[BRANCH_OK]]:
|
// CHECK: ^[[BRANCH_OK]]:
|
||||||
// CHECK: %[[GIVEN:.*]] = async.runtime.create : !async.token
|
// CHECK: %[[GIVEN:.*]] = async.runtime.create : !async.token
|
||||||
|
@ -292,16 +292,16 @@ call @corecursive1(%r): (!async.token) -> ()
|
||||||
|
|
||||||
// CHECK: ^[[RESUME_2]]:
|
// CHECK: ^[[RESUME_2]]:
|
||||||
// CHECK: async.runtime.set_available %[[TOKEN]]
|
// CHECK: async.runtime.set_available %[[TOKEN]]
|
||||||
// CHECK: br ^[[CLEANUP]]
|
// CHECK: cf.br ^[[CLEANUP]]
|
||||||
|
|
||||||
// CHECK: ^[[BRANCH_ERROR]]:
|
// CHECK: ^[[BRANCH_ERROR]]:
|
||||||
// CHECK: async.runtime.set_error %[[TOKEN]]
|
// CHECK: async.runtime.set_error %[[TOKEN]]
|
||||||
// CHECK: br ^[[CLEANUP]]
|
// CHECK: cf.br ^[[CLEANUP]]
|
||||||
return
|
return
|
||||||
|
|
||||||
// CHECK: ^[[CLEANUP]]:
|
// CHECK: ^[[CLEANUP]]:
|
||||||
// CHECK: async.coro.free %[[ID]], %[[HDL]]
|
// CHECK: async.coro.free %[[ID]], %[[HDL]]
|
||||||
// CHECK: br ^[[SUSPEND]]
|
// CHECK: cf.br ^[[SUSPEND]]
|
||||||
|
|
||||||
// CHECK: ^[[SUSPEND]]:
|
// CHECK: ^[[SUSPEND]]:
|
||||||
// CHECK: async.coro.end %[[HDL]]
|
// CHECK: async.coro.end %[[HDL]]
|
||||||
|
|
|
@ -63,7 +63,7 @@ func @nested_async_execute(%arg0: f32, %arg1: f32, %arg2: memref<1xf32>) {
|
||||||
// CHECK: %[[IS_ERROR:.*]] = async.runtime.is_error %[[TOKEN]]
|
// CHECK: %[[IS_ERROR:.*]] = async.runtime.is_error %[[TOKEN]]
|
||||||
// CHECK: %[[TRUE:.*]] = arith.constant true
|
// CHECK: %[[TRUE:.*]] = arith.constant true
|
||||||
// CHECK: %[[NOT_ERROR:.*]] = arith.xori %[[IS_ERROR]], %[[TRUE]] : i1
|
// CHECK: %[[NOT_ERROR:.*]] = arith.xori %[[IS_ERROR]], %[[TRUE]] : i1
|
||||||
// CHECK: assert %[[NOT_ERROR]]
|
// CHECK: cf.assert %[[NOT_ERROR]]
|
||||||
// CHECK-NEXT: return
|
// CHECK-NEXT: return
|
||||||
async.await %token0 : !async.token
|
async.await %token0 : !async.token
|
||||||
return
|
return
|
||||||
|
@ -109,7 +109,7 @@ func @nested_async_execute(%arg0: f32, %arg1: f32, %arg2: memref<1xf32>) {
|
||||||
// Check the error of the awaited token after resumption.
|
// Check the error of the awaited token after resumption.
|
||||||
// CHECK: ^[[RESUME_1]]:
|
// CHECK: ^[[RESUME_1]]:
|
||||||
// CHECK: %[[ERR:.*]] = async.runtime.is_error %[[INNER_TOKEN]]
|
// CHECK: %[[ERR:.*]] = async.runtime.is_error %[[INNER_TOKEN]]
|
||||||
// CHECK: cond_br %[[ERR]], ^[[SET_ERROR:.*]], ^[[CONTINUATION:.*]]
|
// CHECK: cf.cond_br %[[ERR]], ^[[SET_ERROR:.*]], ^[[CONTINUATION:.*]]
|
||||||
|
|
||||||
// Set token available if the token is not in the error state.
|
// Set token available if the token is not in the error state.
|
||||||
// CHECK: ^[[CONTINUATION:.*]]:
|
// CHECK: ^[[CONTINUATION:.*]]:
|
||||||
|
@ -169,7 +169,7 @@ func @async_execute_token_dependency(%arg0: f32, %arg1: memref<1xf32>) {
|
||||||
// Check the error of the awaited token after resumption.
|
// Check the error of the awaited token after resumption.
|
||||||
// CHECK: ^[[RESUME_1]]:
|
// CHECK: ^[[RESUME_1]]:
|
||||||
// CHECK: %[[ERR:.*]] = async.runtime.is_error %[[ARG0]]
|
// CHECK: %[[ERR:.*]] = async.runtime.is_error %[[ARG0]]
|
||||||
// CHECK: cond_br %[[ERR]], ^[[SET_ERROR:.*]], ^[[CONTINUATION:.*]]
|
// CHECK: cf.cond_br %[[ERR]], ^[[SET_ERROR:.*]], ^[[CONTINUATION:.*]]
|
||||||
|
|
||||||
// Emplace result token after second resumption and error checking.
|
// Emplace result token after second resumption and error checking.
|
||||||
// CHECK: ^[[CONTINUATION:.*]]:
|
// CHECK: ^[[CONTINUATION:.*]]:
|
||||||
|
@ -225,7 +225,7 @@ func @async_group_await_all(%arg0: f32, %arg1: memref<1xf32>) {
|
||||||
// Check the error of the awaited token after resumption.
|
// Check the error of the awaited token after resumption.
|
||||||
// CHECK: ^[[RESUME_1]]:
|
// CHECK: ^[[RESUME_1]]:
|
||||||
// CHECK: %[[ERR:.*]] = async.runtime.is_error %[[ARG]]
|
// CHECK: %[[ERR:.*]] = async.runtime.is_error %[[ARG]]
|
||||||
// CHECK: cond_br %[[ERR]], ^[[SET_ERROR:.*]], ^[[CONTINUATION:.*]]
|
// CHECK: cf.cond_br %[[ERR]], ^[[SET_ERROR:.*]], ^[[CONTINUATION:.*]]
|
||||||
|
|
||||||
// Emplace result token after error checking.
|
// Emplace result token after error checking.
|
||||||
// CHECK: ^[[CONTINUATION:.*]]:
|
// CHECK: ^[[CONTINUATION:.*]]:
|
||||||
|
@ -319,7 +319,7 @@ func @async_value_operands() {
|
||||||
// Check the error of the awaited token after resumption.
|
// Check the error of the awaited token after resumption.
|
||||||
// CHECK: ^[[RESUME_1]]:
|
// CHECK: ^[[RESUME_1]]:
|
||||||
// CHECK: %[[ERR:.*]] = async.runtime.is_error %[[ARG]]
|
// CHECK: %[[ERR:.*]] = async.runtime.is_error %[[ARG]]
|
||||||
// CHECK: cond_br %[[ERR]], ^[[SET_ERROR:.*]], ^[[CONTINUATION:.*]]
|
// CHECK: cf.cond_br %[[ERR]], ^[[SET_ERROR:.*]], ^[[CONTINUATION:.*]]
|
||||||
|
|
||||||
// // Load from the async.value argument after error checking.
|
// // Load from the async.value argument after error checking.
|
||||||
// CHECK: ^[[CONTINUATION:.*]]:
|
// CHECK: ^[[CONTINUATION:.*]]:
|
||||||
|
@ -335,7 +335,7 @@ func @async_value_operands() {
|
||||||
// CHECK-LABEL: @execute_assertion
|
// CHECK-LABEL: @execute_assertion
|
||||||
func @execute_assertion(%arg0: i1) {
|
func @execute_assertion(%arg0: i1) {
|
||||||
%token = async.execute {
|
%token = async.execute {
|
||||||
assert %arg0, "error"
|
cf.assert %arg0, "error"
|
||||||
async.yield
|
async.yield
|
||||||
}
|
}
|
||||||
async.await %token : !async.token
|
async.await %token : !async.token
|
||||||
|
@ -358,17 +358,17 @@ func @execute_assertion(%arg0: i1) {
|
||||||
|
|
||||||
// Resume coroutine after suspension.
|
// Resume coroutine after suspension.
|
||||||
// CHECK: ^[[RESUME]]:
|
// CHECK: ^[[RESUME]]:
|
||||||
// CHECK: cond_br %[[ARG0]], ^[[SET_AVAILABLE:.*]], ^[[SET_ERROR:.*]]
|
// CHECK: cf.cond_br %[[ARG0]], ^[[SET_AVAILABLE:.*]], ^[[SET_ERROR:.*]]
|
||||||
|
|
||||||
// Set coroutine completion token to available state.
|
// Set coroutine completion token to available state.
|
||||||
// CHECK: ^[[SET_AVAILABLE]]:
|
// CHECK: ^[[SET_AVAILABLE]]:
|
||||||
// CHECK: async.runtime.set_available %[[TOKEN]]
|
// CHECK: async.runtime.set_available %[[TOKEN]]
|
||||||
// CHECK: br ^[[CLEANUP]]
|
// CHECK: cf.br ^[[CLEANUP]]
|
||||||
|
|
||||||
// Set coroutine completion token to error state.
|
// Set coroutine completion token to error state.
|
||||||
// CHECK: ^[[SET_ERROR]]:
|
// CHECK: ^[[SET_ERROR]]:
|
||||||
// CHECK: async.runtime.set_error %[[TOKEN]]
|
// CHECK: async.runtime.set_error %[[TOKEN]]
|
||||||
// CHECK: br ^[[CLEANUP]]
|
// CHECK: cf.br ^[[CLEANUP]]
|
||||||
|
|
||||||
// Delete coroutine.
|
// Delete coroutine.
|
||||||
// CHECK: ^[[CLEANUP]]:
|
// CHECK: ^[[CLEANUP]]:
|
||||||
|
@ -409,7 +409,7 @@ func @lower_scf_to_cfg(%arg0: f32, %arg1: memref<1xf32>, %arg2: i1) {
|
||||||
|
|
||||||
// Check that structured control flow lowered to CFG.
|
// Check that structured control flow lowered to CFG.
|
||||||
// CHECK-NOT: scf.if
|
// CHECK-NOT: scf.if
|
||||||
// CHECK: cond_br %[[FLAG]]
|
// CHECK: cf.cond_br %[[FLAG]]
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
// Constants captured by the async.execute region should be cloned into the
|
// Constants captured by the async.execute region should be cloned into the
|
||||||
|
|
|
@ -17,26 +17,26 @@
|
||||||
|
|
||||||
// CHECK-LABEL: func @condBranch
|
// CHECK-LABEL: func @condBranch
|
||||||
func @condBranch(%arg0: i1, %arg1: memref<2xf32>, %arg2: memref<2xf32>) {
|
func @condBranch(%arg0: i1, %arg1: memref<2xf32>, %arg2: memref<2xf32>) {
|
||||||
cond_br %arg0, ^bb1, ^bb2
|
cf.cond_br %arg0, ^bb1, ^bb2
|
||||||
^bb1:
|
^bb1:
|
||||||
br ^bb3(%arg1 : memref<2xf32>)
|
cf.br ^bb3(%arg1 : memref<2xf32>)
|
||||||
^bb2:
|
^bb2:
|
||||||
%0 = memref.alloc() : memref<2xf32>
|
%0 = memref.alloc() : memref<2xf32>
|
||||||
test.buffer_based in(%arg1: memref<2xf32>) out(%0: memref<2xf32>)
|
test.buffer_based in(%arg1: memref<2xf32>) out(%0: memref<2xf32>)
|
||||||
br ^bb3(%0 : memref<2xf32>)
|
cf.br ^bb3(%0 : memref<2xf32>)
|
||||||
^bb3(%1: memref<2xf32>):
|
^bb3(%1: memref<2xf32>):
|
||||||
test.copy(%1, %arg2) : (memref<2xf32>, memref<2xf32>)
|
test.copy(%1, %arg2) : (memref<2xf32>, memref<2xf32>)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// CHECK-NEXT: cond_br
|
// CHECK-NEXT: cf.cond_br
|
||||||
// CHECK: %[[ALLOC0:.*]] = bufferization.clone
|
// CHECK: %[[ALLOC0:.*]] = bufferization.clone
|
||||||
// CHECK-NEXT: br ^bb3(%[[ALLOC0]]
|
// CHECK-NEXT: cf.br ^bb3(%[[ALLOC0]]
|
||||||
// CHECK: %[[ALLOC1:.*]] = memref.alloc
|
// CHECK: %[[ALLOC1:.*]] = memref.alloc
|
||||||
// CHECK-NEXT: test.buffer_based
|
// CHECK-NEXT: test.buffer_based
|
||||||
// CHECK-NEXT: %[[ALLOC2:.*]] = bufferization.clone %[[ALLOC1]]
|
// CHECK-NEXT: %[[ALLOC2:.*]] = bufferization.clone %[[ALLOC1]]
|
||||||
// CHECK-NEXT: memref.dealloc %[[ALLOC1]]
|
// CHECK-NEXT: memref.dealloc %[[ALLOC1]]
|
||||||
// CHECK-NEXT: br ^bb3(%[[ALLOC2]]
|
// CHECK-NEXT: cf.br ^bb3(%[[ALLOC2]]
|
||||||
// CHECK: test.copy
|
// CHECK: test.copy
|
||||||
// CHECK-NEXT: memref.dealloc
|
// CHECK-NEXT: memref.dealloc
|
||||||
// CHECK-NEXT: return
|
// CHECK-NEXT: return
|
||||||
|
@ -62,27 +62,27 @@ func @condBranchDynamicType(
|
||||||
%arg1: memref<?xf32>,
|
%arg1: memref<?xf32>,
|
||||||
%arg2: memref<?xf32>,
|
%arg2: memref<?xf32>,
|
||||||
%arg3: index) {
|
%arg3: index) {
|
||||||
cond_br %arg0, ^bb1, ^bb2(%arg3: index)
|
cf.cond_br %arg0, ^bb1, ^bb2(%arg3: index)
|
||||||
^bb1:
|
^bb1:
|
||||||
br ^bb3(%arg1 : memref<?xf32>)
|
cf.br ^bb3(%arg1 : memref<?xf32>)
|
||||||
^bb2(%0: index):
|
^bb2(%0: index):
|
||||||
%1 = memref.alloc(%0) : memref<?xf32>
|
%1 = memref.alloc(%0) : memref<?xf32>
|
||||||
test.buffer_based in(%arg1: memref<?xf32>) out(%1: memref<?xf32>)
|
test.buffer_based in(%arg1: memref<?xf32>) out(%1: memref<?xf32>)
|
||||||
br ^bb3(%1 : memref<?xf32>)
|
cf.br ^bb3(%1 : memref<?xf32>)
|
||||||
^bb3(%2: memref<?xf32>):
|
^bb3(%2: memref<?xf32>):
|
||||||
test.copy(%2, %arg2) : (memref<?xf32>, memref<?xf32>)
|
test.copy(%2, %arg2) : (memref<?xf32>, memref<?xf32>)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// CHECK-NEXT: cond_br
|
// CHECK-NEXT: cf.cond_br
|
||||||
// CHECK: %[[ALLOC0:.*]] = bufferization.clone
|
// CHECK: %[[ALLOC0:.*]] = bufferization.clone
|
||||||
// CHECK-NEXT: br ^bb3(%[[ALLOC0]]
|
// CHECK-NEXT: cf.br ^bb3(%[[ALLOC0]]
|
||||||
// CHECK: ^bb2(%[[IDX:.*]]:{{.*}})
|
// CHECK: ^bb2(%[[IDX:.*]]:{{.*}})
|
||||||
// CHECK-NEXT: %[[ALLOC1:.*]] = memref.alloc(%[[IDX]])
|
// CHECK-NEXT: %[[ALLOC1:.*]] = memref.alloc(%[[IDX]])
|
||||||
// CHECK-NEXT: test.buffer_based
|
// CHECK-NEXT: test.buffer_based
|
||||||
// CHECK-NEXT: %[[ALLOC2:.*]] = bufferization.clone
|
// CHECK-NEXT: %[[ALLOC2:.*]] = bufferization.clone
|
||||||
// CHECK-NEXT: memref.dealloc %[[ALLOC1]]
|
// CHECK-NEXT: memref.dealloc %[[ALLOC1]]
|
||||||
// CHECK-NEXT: br ^bb3
|
// CHECK-NEXT: cf.br ^bb3
|
||||||
// CHECK-NEXT: ^bb3(%[[ALLOC3:.*]]:{{.*}})
|
// CHECK-NEXT: ^bb3(%[[ALLOC3:.*]]:{{.*}})
|
||||||
// CHECK: test.copy(%[[ALLOC3]],
|
// CHECK: test.copy(%[[ALLOC3]],
|
||||||
// CHECK-NEXT: memref.dealloc %[[ALLOC3]]
|
// CHECK-NEXT: memref.dealloc %[[ALLOC3]]
|
||||||
|
@ -98,28 +98,28 @@ func @condBranchUnrankedType(
|
||||||
%arg1: memref<*xf32>,
|
%arg1: memref<*xf32>,
|
||||||
%arg2: memref<*xf32>,
|
%arg2: memref<*xf32>,
|
||||||
%arg3: index) {
|
%arg3: index) {
|
||||||
cond_br %arg0, ^bb1, ^bb2(%arg3: index)
|
cf.cond_br %arg0, ^bb1, ^bb2(%arg3: index)
|
||||||
^bb1:
|
^bb1:
|
||||||
br ^bb3(%arg1 : memref<*xf32>)
|
cf.br ^bb3(%arg1 : memref<*xf32>)
|
||||||
^bb2(%0: index):
|
^bb2(%0: index):
|
||||||
%1 = memref.alloc(%0) : memref<?xf32>
|
%1 = memref.alloc(%0) : memref<?xf32>
|
||||||
%2 = memref.cast %1 : memref<?xf32> to memref<*xf32>
|
%2 = memref.cast %1 : memref<?xf32> to memref<*xf32>
|
||||||
test.buffer_based in(%arg1: memref<*xf32>) out(%2: memref<*xf32>)
|
test.buffer_based in(%arg1: memref<*xf32>) out(%2: memref<*xf32>)
|
||||||
br ^bb3(%2 : memref<*xf32>)
|
cf.br ^bb3(%2 : memref<*xf32>)
|
||||||
^bb3(%3: memref<*xf32>):
|
^bb3(%3: memref<*xf32>):
|
||||||
test.copy(%3, %arg2) : (memref<*xf32>, memref<*xf32>)
|
test.copy(%3, %arg2) : (memref<*xf32>, memref<*xf32>)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// CHECK-NEXT: cond_br
|
// CHECK-NEXT: cf.cond_br
|
||||||
// CHECK: %[[ALLOC0:.*]] = bufferization.clone
|
// CHECK: %[[ALLOC0:.*]] = bufferization.clone
|
||||||
// CHECK-NEXT: br ^bb3(%[[ALLOC0]]
|
// CHECK-NEXT: cf.br ^bb3(%[[ALLOC0]]
|
||||||
// CHECK: ^bb2(%[[IDX:.*]]:{{.*}})
|
// CHECK: ^bb2(%[[IDX:.*]]:{{.*}})
|
||||||
// CHECK-NEXT: %[[ALLOC1:.*]] = memref.alloc(%[[IDX]])
|
// CHECK-NEXT: %[[ALLOC1:.*]] = memref.alloc(%[[IDX]])
|
||||||
// CHECK: test.buffer_based
|
// CHECK: test.buffer_based
|
||||||
// CHECK-NEXT: %[[ALLOC2:.*]] = bufferization.clone
|
// CHECK-NEXT: %[[ALLOC2:.*]] = bufferization.clone
|
||||||
// CHECK-NEXT: memref.dealloc %[[ALLOC1]]
|
// CHECK-NEXT: memref.dealloc %[[ALLOC1]]
|
||||||
// CHECK-NEXT: br ^bb3
|
// CHECK-NEXT: cf.br ^bb3
|
||||||
// CHECK-NEXT: ^bb3(%[[ALLOC3:.*]]:{{.*}})
|
// CHECK-NEXT: ^bb3(%[[ALLOC3:.*]]:{{.*}})
|
||||||
// CHECK: test.copy(%[[ALLOC3]],
|
// CHECK: test.copy(%[[ALLOC3]],
|
||||||
// CHECK-NEXT: memref.dealloc %[[ALLOC3]]
|
// CHECK-NEXT: memref.dealloc %[[ALLOC3]]
|
||||||
|
@ -153,44 +153,44 @@ func @condBranchDynamicTypeNested(
|
||||||
%arg1: memref<?xf32>,
|
%arg1: memref<?xf32>,
|
||||||
%arg2: memref<?xf32>,
|
%arg2: memref<?xf32>,
|
||||||
%arg3: index) {
|
%arg3: index) {
|
||||||
cond_br %arg0, ^bb1, ^bb2(%arg3: index)
|
cf.cond_br %arg0, ^bb1, ^bb2(%arg3: index)
|
||||||
^bb1:
|
^bb1:
|
||||||
br ^bb6(%arg1 : memref<?xf32>)
|
cf.br ^bb6(%arg1 : memref<?xf32>)
|
||||||
^bb2(%0: index):
|
^bb2(%0: index):
|
||||||
%1 = memref.alloc(%0) : memref<?xf32>
|
%1 = memref.alloc(%0) : memref<?xf32>
|
||||||
test.buffer_based in(%arg1: memref<?xf32>) out(%1: memref<?xf32>)
|
test.buffer_based in(%arg1: memref<?xf32>) out(%1: memref<?xf32>)
|
||||||
cond_br %arg0, ^bb3, ^bb4
|
cf.cond_br %arg0, ^bb3, ^bb4
|
||||||
^bb3:
|
^bb3:
|
||||||
br ^bb5(%1 : memref<?xf32>)
|
cf.br ^bb5(%1 : memref<?xf32>)
|
||||||
^bb4:
|
^bb4:
|
||||||
br ^bb5(%1 : memref<?xf32>)
|
cf.br ^bb5(%1 : memref<?xf32>)
|
||||||
^bb5(%2: memref<?xf32>):
|
^bb5(%2: memref<?xf32>):
|
||||||
br ^bb6(%2 : memref<?xf32>)
|
cf.br ^bb6(%2 : memref<?xf32>)
|
||||||
^bb6(%3: memref<?xf32>):
|
^bb6(%3: memref<?xf32>):
|
||||||
br ^bb7(%3 : memref<?xf32>)
|
cf.br ^bb7(%3 : memref<?xf32>)
|
||||||
^bb7(%4: memref<?xf32>):
|
^bb7(%4: memref<?xf32>):
|
||||||
test.copy(%4, %arg2) : (memref<?xf32>, memref<?xf32>)
|
test.copy(%4, %arg2) : (memref<?xf32>, memref<?xf32>)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// CHECK-NEXT: cond_br{{.*}}
|
// CHECK-NEXT: cf.cond_br{{.*}}
|
||||||
// CHECK-NEXT: ^bb1
|
// CHECK-NEXT: ^bb1
|
||||||
// CHECK-NEXT: %[[ALLOC0:.*]] = bufferization.clone
|
// CHECK-NEXT: %[[ALLOC0:.*]] = bufferization.clone
|
||||||
// CHECK-NEXT: br ^bb6(%[[ALLOC0]]
|
// CHECK-NEXT: cf.br ^bb6(%[[ALLOC0]]
|
||||||
// CHECK: ^bb2(%[[IDX:.*]]:{{.*}})
|
// CHECK: ^bb2(%[[IDX:.*]]:{{.*}})
|
||||||
// CHECK-NEXT: %[[ALLOC1:.*]] = memref.alloc(%[[IDX]])
|
// CHECK-NEXT: %[[ALLOC1:.*]] = memref.alloc(%[[IDX]])
|
||||||
// CHECK-NEXT: test.buffer_based
|
// CHECK-NEXT: test.buffer_based
|
||||||
// CHECK: cond_br
|
// CHECK: cf.cond_br
|
||||||
// CHECK: ^bb3:
|
// CHECK: ^bb3:
|
||||||
// CHECK-NEXT: br ^bb5(%[[ALLOC1]]{{.*}})
|
// CHECK-NEXT: cf.br ^bb5(%[[ALLOC1]]{{.*}})
|
||||||
// CHECK: ^bb4:
|
// CHECK: ^bb4:
|
||||||
// CHECK-NEXT: br ^bb5(%[[ALLOC1]]{{.*}})
|
// CHECK-NEXT: cf.br ^bb5(%[[ALLOC1]]{{.*}})
|
||||||
// CHECK-NEXT: ^bb5(%[[ALLOC2:.*]]:{{.*}})
|
// CHECK-NEXT: ^bb5(%[[ALLOC2:.*]]:{{.*}})
|
||||||
// CHECK-NEXT: %[[ALLOC3:.*]] = bufferization.clone %[[ALLOC2]]
|
// CHECK-NEXT: %[[ALLOC3:.*]] = bufferization.clone %[[ALLOC2]]
|
||||||
// CHECK-NEXT: memref.dealloc %[[ALLOC1]]
|
// CHECK-NEXT: memref.dealloc %[[ALLOC1]]
|
||||||
// CHECK-NEXT: br ^bb6(%[[ALLOC3]]{{.*}})
|
// CHECK-NEXT: cf.br ^bb6(%[[ALLOC3]]{{.*}})
|
||||||
// CHECK-NEXT: ^bb6(%[[ALLOC4:.*]]:{{.*}})
|
// CHECK-NEXT: ^bb6(%[[ALLOC4:.*]]:{{.*}})
|
||||||
// CHECK-NEXT: br ^bb7(%[[ALLOC4]]{{.*}})
|
// CHECK-NEXT: cf.br ^bb7(%[[ALLOC4]]{{.*}})
|
||||||
// CHECK-NEXT: ^bb7(%[[ALLOC5:.*]]:{{.*}})
|
// CHECK-NEXT: ^bb7(%[[ALLOC5:.*]]:{{.*}})
|
||||||
// CHECK: test.copy(%[[ALLOC5]],
|
// CHECK: test.copy(%[[ALLOC5]],
|
||||||
// CHECK-NEXT: memref.dealloc %[[ALLOC4]]
|
// CHECK-NEXT: memref.dealloc %[[ALLOC4]]
|
||||||
|
@ -225,18 +225,18 @@ func @emptyUsesValue(%arg0: memref<4xf32>) {
|
||||||
|
|
||||||
// CHECK-LABEL: func @criticalEdge
|
// CHECK-LABEL: func @criticalEdge
|
||||||
func @criticalEdge(%arg0: i1, %arg1: memref<2xf32>, %arg2: memref<2xf32>) {
|
func @criticalEdge(%arg0: i1, %arg1: memref<2xf32>, %arg2: memref<2xf32>) {
|
||||||
cond_br %arg0, ^bb1, ^bb2(%arg1 : memref<2xf32>)
|
cf.cond_br %arg0, ^bb1, ^bb2(%arg1 : memref<2xf32>)
|
||||||
^bb1:
|
^bb1:
|
||||||
%0 = memref.alloc() : memref<2xf32>
|
%0 = memref.alloc() : memref<2xf32>
|
||||||
test.buffer_based in(%arg1: memref<2xf32>) out(%0: memref<2xf32>)
|
test.buffer_based in(%arg1: memref<2xf32>) out(%0: memref<2xf32>)
|
||||||
br ^bb2(%0 : memref<2xf32>)
|
cf.br ^bb2(%0 : memref<2xf32>)
|
||||||
^bb2(%1: memref<2xf32>):
|
^bb2(%1: memref<2xf32>):
|
||||||
test.copy(%1, %arg2) : (memref<2xf32>, memref<2xf32>)
|
test.copy(%1, %arg2) : (memref<2xf32>, memref<2xf32>)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// CHECK-NEXT: %[[ALLOC0:.*]] = bufferization.clone
|
// CHECK-NEXT: %[[ALLOC0:.*]] = bufferization.clone
|
||||||
// CHECK-NEXT: cond_br
|
// CHECK-NEXT: cf.cond_br
|
||||||
// CHECK: %[[ALLOC1:.*]] = memref.alloc()
|
// CHECK: %[[ALLOC1:.*]] = memref.alloc()
|
||||||
// CHECK-NEXT: test.buffer_based
|
// CHECK-NEXT: test.buffer_based
|
||||||
// CHECK-NEXT: %[[ALLOC2:.*]] = bufferization.clone %[[ALLOC1]]
|
// CHECK-NEXT: %[[ALLOC2:.*]] = bufferization.clone %[[ALLOC1]]
|
||||||
|
@ -260,9 +260,9 @@ func @criticalEdge(%arg0: i1, %arg1: memref<2xf32>, %arg2: memref<2xf32>) {
|
||||||
func @invCriticalEdge(%arg0: i1, %arg1: memref<2xf32>, %arg2: memref<2xf32>) {
|
func @invCriticalEdge(%arg0: i1, %arg1: memref<2xf32>, %arg2: memref<2xf32>) {
|
||||||
%0 = memref.alloc() : memref<2xf32>
|
%0 = memref.alloc() : memref<2xf32>
|
||||||
test.buffer_based in(%arg1: memref<2xf32>) out(%0: memref<2xf32>)
|
test.buffer_based in(%arg1: memref<2xf32>) out(%0: memref<2xf32>)
|
||||||
cond_br %arg0, ^bb1, ^bb2(%arg1 : memref<2xf32>)
|
cf.cond_br %arg0, ^bb1, ^bb2(%arg1 : memref<2xf32>)
|
||||||
^bb1:
|
^bb1:
|
||||||
br ^bb2(%0 : memref<2xf32>)
|
cf.br ^bb2(%0 : memref<2xf32>)
|
||||||
^bb2(%1: memref<2xf32>):
|
^bb2(%1: memref<2xf32>):
|
||||||
test.copy(%1, %arg2) : (memref<2xf32>, memref<2xf32>)
|
test.copy(%1, %arg2) : (memref<2xf32>, memref<2xf32>)
|
||||||
return
|
return
|
||||||
|
@ -288,13 +288,13 @@ func @invCriticalEdge(%arg0: i1, %arg1: memref<2xf32>, %arg2: memref<2xf32>) {
|
||||||
func @ifElse(%arg0: i1, %arg1: memref<2xf32>, %arg2: memref<2xf32>) {
|
func @ifElse(%arg0: i1, %arg1: memref<2xf32>, %arg2: memref<2xf32>) {
|
||||||
%0 = memref.alloc() : memref<2xf32>
|
%0 = memref.alloc() : memref<2xf32>
|
||||||
test.buffer_based in(%arg1: memref<2xf32>) out(%0: memref<2xf32>)
|
test.buffer_based in(%arg1: memref<2xf32>) out(%0: memref<2xf32>)
|
||||||
cond_br %arg0,
|
cf.cond_br %arg0,
|
||||||
^bb1(%arg1, %0 : memref<2xf32>, memref<2xf32>),
|
^bb1(%arg1, %0 : memref<2xf32>, memref<2xf32>),
|
||||||
^bb2(%0, %arg1 : memref<2xf32>, memref<2xf32>)
|
^bb2(%0, %arg1 : memref<2xf32>, memref<2xf32>)
|
||||||
^bb1(%1: memref<2xf32>, %2: memref<2xf32>):
|
^bb1(%1: memref<2xf32>, %2: memref<2xf32>):
|
||||||
br ^bb3(%1, %2 : memref<2xf32>, memref<2xf32>)
|
cf.br ^bb3(%1, %2 : memref<2xf32>, memref<2xf32>)
|
||||||
^bb2(%3: memref<2xf32>, %4: memref<2xf32>):
|
^bb2(%3: memref<2xf32>, %4: memref<2xf32>):
|
||||||
br ^bb3(%3, %4 : memref<2xf32>, memref<2xf32>)
|
cf.br ^bb3(%3, %4 : memref<2xf32>, memref<2xf32>)
|
||||||
^bb3(%5: memref<2xf32>, %6: memref<2xf32>):
|
^bb3(%5: memref<2xf32>, %6: memref<2xf32>):
|
||||||
%7 = memref.alloc() : memref<2xf32>
|
%7 = memref.alloc() : memref<2xf32>
|
||||||
test.buffer_based in(%5: memref<2xf32>) out(%7: memref<2xf32>)
|
test.buffer_based in(%5: memref<2xf32>) out(%7: memref<2xf32>)
|
||||||
|
@ -326,13 +326,13 @@ func @ifElse(%arg0: i1, %arg1: memref<2xf32>, %arg2: memref<2xf32>) {
|
||||||
func @ifElseNoUsers(%arg0: i1, %arg1: memref<2xf32>, %arg2: memref<2xf32>) {
|
func @ifElseNoUsers(%arg0: i1, %arg1: memref<2xf32>, %arg2: memref<2xf32>) {
|
||||||
%0 = memref.alloc() : memref<2xf32>
|
%0 = memref.alloc() : memref<2xf32>
|
||||||
test.buffer_based in(%arg1: memref<2xf32>) out(%0: memref<2xf32>)
|
test.buffer_based in(%arg1: memref<2xf32>) out(%0: memref<2xf32>)
|
||||||
cond_br %arg0,
|
cf.cond_br %arg0,
|
||||||
^bb1(%arg1, %0 : memref<2xf32>, memref<2xf32>),
|
^bb1(%arg1, %0 : memref<2xf32>, memref<2xf32>),
|
||||||
^bb2(%0, %arg1 : memref<2xf32>, memref<2xf32>)
|
^bb2(%0, %arg1 : memref<2xf32>, memref<2xf32>)
|
||||||
^bb1(%1: memref<2xf32>, %2: memref<2xf32>):
|
^bb1(%1: memref<2xf32>, %2: memref<2xf32>):
|
||||||
br ^bb3(%1, %2 : memref<2xf32>, memref<2xf32>)
|
cf.br ^bb3(%1, %2 : memref<2xf32>, memref<2xf32>)
|
||||||
^bb2(%3: memref<2xf32>, %4: memref<2xf32>):
|
^bb2(%3: memref<2xf32>, %4: memref<2xf32>):
|
||||||
br ^bb3(%3, %4 : memref<2xf32>, memref<2xf32>)
|
cf.br ^bb3(%3, %4 : memref<2xf32>, memref<2xf32>)
|
||||||
^bb3(%5: memref<2xf32>, %6: memref<2xf32>):
|
^bb3(%5: memref<2xf32>, %6: memref<2xf32>):
|
||||||
test.copy(%arg1, %arg2) : (memref<2xf32>, memref<2xf32>)
|
test.copy(%arg1, %arg2) : (memref<2xf32>, memref<2xf32>)
|
||||||
return
|
return
|
||||||
|
@ -361,17 +361,17 @@ func @ifElseNoUsers(%arg0: i1, %arg1: memref<2xf32>, %arg2: memref<2xf32>) {
|
||||||
func @ifElseNested(%arg0: i1, %arg1: memref<2xf32>, %arg2: memref<2xf32>) {
|
func @ifElseNested(%arg0: i1, %arg1: memref<2xf32>, %arg2: memref<2xf32>) {
|
||||||
%0 = memref.alloc() : memref<2xf32>
|
%0 = memref.alloc() : memref<2xf32>
|
||||||
test.buffer_based in(%arg1: memref<2xf32>) out(%0: memref<2xf32>)
|
test.buffer_based in(%arg1: memref<2xf32>) out(%0: memref<2xf32>)
|
||||||
cond_br %arg0,
|
cf.cond_br %arg0,
|
||||||
^bb1(%arg1, %0 : memref<2xf32>, memref<2xf32>),
|
^bb1(%arg1, %0 : memref<2xf32>, memref<2xf32>),
|
||||||
^bb2(%0, %arg1 : memref<2xf32>, memref<2xf32>)
|
^bb2(%0, %arg1 : memref<2xf32>, memref<2xf32>)
|
||||||
^bb1(%1: memref<2xf32>, %2: memref<2xf32>):
|
^bb1(%1: memref<2xf32>, %2: memref<2xf32>):
|
||||||
br ^bb5(%1, %2 : memref<2xf32>, memref<2xf32>)
|
cf.br ^bb5(%1, %2 : memref<2xf32>, memref<2xf32>)
|
||||||
^bb2(%3: memref<2xf32>, %4: memref<2xf32>):
|
^bb2(%3: memref<2xf32>, %4: memref<2xf32>):
|
||||||
cond_br %arg0, ^bb3(%3 : memref<2xf32>), ^bb4(%4 : memref<2xf32>)
|
cf.cond_br %arg0, ^bb3(%3 : memref<2xf32>), ^bb4(%4 : memref<2xf32>)
|
||||||
^bb3(%5: memref<2xf32>):
|
^bb3(%5: memref<2xf32>):
|
||||||
br ^bb5(%5, %3 : memref<2xf32>, memref<2xf32>)
|
cf.br ^bb5(%5, %3 : memref<2xf32>, memref<2xf32>)
|
||||||
^bb4(%6: memref<2xf32>):
|
^bb4(%6: memref<2xf32>):
|
||||||
br ^bb5(%3, %6 : memref<2xf32>, memref<2xf32>)
|
cf.br ^bb5(%3, %6 : memref<2xf32>, memref<2xf32>)
|
||||||
^bb5(%7: memref<2xf32>, %8: memref<2xf32>):
|
^bb5(%7: memref<2xf32>, %8: memref<2xf32>):
|
||||||
%9 = memref.alloc() : memref<2xf32>
|
%9 = memref.alloc() : memref<2xf32>
|
||||||
test.buffer_based in(%7: memref<2xf32>) out(%9: memref<2xf32>)
|
test.buffer_based in(%7: memref<2xf32>) out(%9: memref<2xf32>)
|
||||||
|
@ -430,33 +430,33 @@ func @moving_alloc_and_inserting_missing_dealloc(
|
||||||
%cond: i1,
|
%cond: i1,
|
||||||
%arg0: memref<2xf32>,
|
%arg0: memref<2xf32>,
|
||||||
%arg1: memref<2xf32>) {
|
%arg1: memref<2xf32>) {
|
||||||
cond_br %cond, ^bb1, ^bb2
|
cf.cond_br %cond, ^bb1, ^bb2
|
||||||
^bb1:
|
^bb1:
|
||||||
%0 = memref.alloc() : memref<2xf32>
|
%0 = memref.alloc() : memref<2xf32>
|
||||||
test.buffer_based in(%arg0: memref<2xf32>) out(%0: memref<2xf32>)
|
test.buffer_based in(%arg0: memref<2xf32>) out(%0: memref<2xf32>)
|
||||||
br ^exit(%0 : memref<2xf32>)
|
cf.br ^exit(%0 : memref<2xf32>)
|
||||||
^bb2:
|
^bb2:
|
||||||
%1 = memref.alloc() : memref<2xf32>
|
%1 = memref.alloc() : memref<2xf32>
|
||||||
test.buffer_based in(%arg0: memref<2xf32>) out(%1: memref<2xf32>)
|
test.buffer_based in(%arg0: memref<2xf32>) out(%1: memref<2xf32>)
|
||||||
br ^exit(%1 : memref<2xf32>)
|
cf.br ^exit(%1 : memref<2xf32>)
|
||||||
^exit(%arg2: memref<2xf32>):
|
^exit(%arg2: memref<2xf32>):
|
||||||
test.copy(%arg2, %arg1) : (memref<2xf32>, memref<2xf32>)
|
test.copy(%arg2, %arg1) : (memref<2xf32>, memref<2xf32>)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// CHECK-NEXT: cond_br{{.*}}
|
// CHECK-NEXT: cf.cond_br{{.*}}
|
||||||
// CHECK-NEXT: ^bb1
|
// CHECK-NEXT: ^bb1
|
||||||
// CHECK: %[[ALLOC0:.*]] = memref.alloc()
|
// CHECK: %[[ALLOC0:.*]] = memref.alloc()
|
||||||
// CHECK-NEXT: test.buffer_based
|
// CHECK-NEXT: test.buffer_based
|
||||||
// CHECK-NEXT: %[[ALLOC1:.*]] = bufferization.clone %[[ALLOC0]]
|
// CHECK-NEXT: %[[ALLOC1:.*]] = bufferization.clone %[[ALLOC0]]
|
||||||
// CHECK-NEXT: memref.dealloc %[[ALLOC0]]
|
// CHECK-NEXT: memref.dealloc %[[ALLOC0]]
|
||||||
// CHECK-NEXT: br ^bb3(%[[ALLOC1]]
|
// CHECK-NEXT: cf.br ^bb3(%[[ALLOC1]]
|
||||||
// CHECK-NEXT: ^bb2
|
// CHECK-NEXT: ^bb2
|
||||||
// CHECK-NEXT: %[[ALLOC2:.*]] = memref.alloc()
|
// CHECK-NEXT: %[[ALLOC2:.*]] = memref.alloc()
|
||||||
// CHECK-NEXT: test.buffer_based
|
// CHECK-NEXT: test.buffer_based
|
||||||
// CHECK-NEXT: %[[ALLOC3:.*]] = bufferization.clone %[[ALLOC2]]
|
// CHECK-NEXT: %[[ALLOC3:.*]] = bufferization.clone %[[ALLOC2]]
|
||||||
// CHECK-NEXT: memref.dealloc %[[ALLOC2]]
|
// CHECK-NEXT: memref.dealloc %[[ALLOC2]]
|
||||||
// CHECK-NEXT: br ^bb3(%[[ALLOC3]]
|
// CHECK-NEXT: cf.br ^bb3(%[[ALLOC3]]
|
||||||
// CHECK-NEXT: ^bb3(%[[ALLOC4:.*]]:{{.*}})
|
// CHECK-NEXT: ^bb3(%[[ALLOC4:.*]]:{{.*}})
|
||||||
// CHECK: test.copy
|
// CHECK: test.copy
|
||||||
// CHECK-NEXT: memref.dealloc %[[ALLOC4]]
|
// CHECK-NEXT: memref.dealloc %[[ALLOC4]]
|
||||||
|
@ -480,20 +480,20 @@ func @moving_invalid_dealloc_op_complex(
|
||||||
%arg0: memref<2xf32>,
|
%arg0: memref<2xf32>,
|
||||||
%arg1: memref<2xf32>) {
|
%arg1: memref<2xf32>) {
|
||||||
%1 = memref.alloc() : memref<2xf32>
|
%1 = memref.alloc() : memref<2xf32>
|
||||||
cond_br %cond, ^bb1, ^bb2
|
cf.cond_br %cond, ^bb1, ^bb2
|
||||||
^bb1:
|
^bb1:
|
||||||
br ^exit(%arg0 : memref<2xf32>)
|
cf.br ^exit(%arg0 : memref<2xf32>)
|
||||||
^bb2:
|
^bb2:
|
||||||
test.buffer_based in(%arg0: memref<2xf32>) out(%1: memref<2xf32>)
|
test.buffer_based in(%arg0: memref<2xf32>) out(%1: memref<2xf32>)
|
||||||
memref.dealloc %1 : memref<2xf32>
|
memref.dealloc %1 : memref<2xf32>
|
||||||
br ^exit(%1 : memref<2xf32>)
|
cf.br ^exit(%1 : memref<2xf32>)
|
||||||
^exit(%arg2: memref<2xf32>):
|
^exit(%arg2: memref<2xf32>):
|
||||||
test.copy(%arg2, %arg1) : (memref<2xf32>, memref<2xf32>)
|
test.copy(%arg2, %arg1) : (memref<2xf32>, memref<2xf32>)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// CHECK-NEXT: %[[ALLOC0:.*]] = memref.alloc()
|
// CHECK-NEXT: %[[ALLOC0:.*]] = memref.alloc()
|
||||||
// CHECK-NEXT: cond_br
|
// CHECK-NEXT: cf.cond_br
|
||||||
// CHECK: test.copy
|
// CHECK: test.copy
|
||||||
// CHECK-NEXT: memref.dealloc %[[ALLOC0]]
|
// CHECK-NEXT: memref.dealloc %[[ALLOC0]]
|
||||||
// CHECK-NEXT: return
|
// CHECK-NEXT: return
|
||||||
|
@ -548,9 +548,9 @@ func @nested_regions_and_cond_branch(
|
||||||
%arg0: i1,
|
%arg0: i1,
|
||||||
%arg1: memref<2xf32>,
|
%arg1: memref<2xf32>,
|
||||||
%arg2: memref<2xf32>) {
|
%arg2: memref<2xf32>) {
|
||||||
cond_br %arg0, ^bb1, ^bb2
|
cf.cond_br %arg0, ^bb1, ^bb2
|
||||||
^bb1:
|
^bb1:
|
||||||
br ^bb3(%arg1 : memref<2xf32>)
|
cf.br ^bb3(%arg1 : memref<2xf32>)
|
||||||
^bb2:
|
^bb2:
|
||||||
%0 = memref.alloc() : memref<2xf32>
|
%0 = memref.alloc() : memref<2xf32>
|
||||||
test.region_buffer_based in(%arg1: memref<2xf32>) out(%0: memref<2xf32>) {
|
test.region_buffer_based in(%arg1: memref<2xf32>) out(%0: memref<2xf32>) {
|
||||||
|
@ -560,13 +560,13 @@ func @nested_regions_and_cond_branch(
|
||||||
%tmp1 = math.exp %gen1_arg0 : f32
|
%tmp1 = math.exp %gen1_arg0 : f32
|
||||||
test.region_yield %tmp1 : f32
|
test.region_yield %tmp1 : f32
|
||||||
}
|
}
|
||||||
br ^bb3(%0 : memref<2xf32>)
|
cf.br ^bb3(%0 : memref<2xf32>)
|
||||||
^bb3(%1: memref<2xf32>):
|
^bb3(%1: memref<2xf32>):
|
||||||
test.copy(%1, %arg2) : (memref<2xf32>, memref<2xf32>)
|
test.copy(%1, %arg2) : (memref<2xf32>, memref<2xf32>)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
// CHECK: (%[[cond:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}, %{{.*}}: {{.*}})
|
// CHECK: (%[[cond:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}, %{{.*}}: {{.*}})
|
||||||
// CHECK-NEXT: cond_br %[[cond]], ^[[BB1:.*]], ^[[BB2:.*]]
|
// CHECK-NEXT: cf.cond_br %[[cond]], ^[[BB1:.*]], ^[[BB2:.*]]
|
||||||
// CHECK: %[[ALLOC0:.*]] = bufferization.clone %[[ARG1]]
|
// CHECK: %[[ALLOC0:.*]] = bufferization.clone %[[ARG1]]
|
||||||
// CHECK: ^[[BB2]]:
|
// CHECK: ^[[BB2]]:
|
||||||
// CHECK: %[[ALLOC1:.*]] = memref.alloc()
|
// CHECK: %[[ALLOC1:.*]] = memref.alloc()
|
||||||
|
@ -728,21 +728,21 @@ func @subview(%arg0 : index, %arg1 : index, %arg2 : memref<?x?xf32>) {
|
||||||
|
|
||||||
// CHECK-LABEL: func @condBranchAlloca
|
// CHECK-LABEL: func @condBranchAlloca
|
||||||
func @condBranchAlloca(%arg0: i1, %arg1: memref<2xf32>, %arg2: memref<2xf32>) {
|
func @condBranchAlloca(%arg0: i1, %arg1: memref<2xf32>, %arg2: memref<2xf32>) {
|
||||||
cond_br %arg0, ^bb1, ^bb2
|
cf.cond_br %arg0, ^bb1, ^bb2
|
||||||
^bb1:
|
^bb1:
|
||||||
br ^bb3(%arg1 : memref<2xf32>)
|
cf.br ^bb3(%arg1 : memref<2xf32>)
|
||||||
^bb2:
|
^bb2:
|
||||||
%0 = memref.alloca() : memref<2xf32>
|
%0 = memref.alloca() : memref<2xf32>
|
||||||
test.buffer_based in(%arg1: memref<2xf32>) out(%0: memref<2xf32>)
|
test.buffer_based in(%arg1: memref<2xf32>) out(%0: memref<2xf32>)
|
||||||
br ^bb3(%0 : memref<2xf32>)
|
cf.br ^bb3(%0 : memref<2xf32>)
|
||||||
^bb3(%1: memref<2xf32>):
|
^bb3(%1: memref<2xf32>):
|
||||||
test.copy(%1, %arg2) : (memref<2xf32>, memref<2xf32>)
|
test.copy(%1, %arg2) : (memref<2xf32>, memref<2xf32>)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// CHECK-NEXT: cond_br
|
// CHECK-NEXT: cf.cond_br
|
||||||
// CHECK: %[[ALLOCA:.*]] = memref.alloca()
|
// CHECK: %[[ALLOCA:.*]] = memref.alloca()
|
||||||
// CHECK: br ^bb3(%[[ALLOCA:.*]])
|
// CHECK: cf.br ^bb3(%[[ALLOCA:.*]])
|
||||||
// CHECK-NEXT: ^bb3
|
// CHECK-NEXT: ^bb3
|
||||||
// CHECK-NEXT: test.copy
|
// CHECK-NEXT: test.copy
|
||||||
// CHECK-NEXT: return
|
// CHECK-NEXT: return
|
||||||
|
@ -757,13 +757,13 @@ func @condBranchAlloca(%arg0: i1, %arg1: memref<2xf32>, %arg2: memref<2xf32>) {
|
||||||
func @ifElseAlloca(%arg0: i1, %arg1: memref<2xf32>, %arg2: memref<2xf32>) {
|
func @ifElseAlloca(%arg0: i1, %arg1: memref<2xf32>, %arg2: memref<2xf32>) {
|
||||||
%0 = memref.alloc() : memref<2xf32>
|
%0 = memref.alloc() : memref<2xf32>
|
||||||
test.buffer_based in(%arg1: memref<2xf32>) out(%0: memref<2xf32>)
|
test.buffer_based in(%arg1: memref<2xf32>) out(%0: memref<2xf32>)
|
||||||
cond_br %arg0,
|
cf.cond_br %arg0,
|
||||||
^bb1(%arg1, %0 : memref<2xf32>, memref<2xf32>),
|
^bb1(%arg1, %0 : memref<2xf32>, memref<2xf32>),
|
||||||
^bb2(%0, %arg1 : memref<2xf32>, memref<2xf32>)
|
^bb2(%0, %arg1 : memref<2xf32>, memref<2xf32>)
|
||||||
^bb1(%1: memref<2xf32>, %2: memref<2xf32>):
|
^bb1(%1: memref<2xf32>, %2: memref<2xf32>):
|
||||||
br ^bb3(%1, %2 : memref<2xf32>, memref<2xf32>)
|
cf.br ^bb3(%1, %2 : memref<2xf32>, memref<2xf32>)
|
||||||
^bb2(%3: memref<2xf32>, %4: memref<2xf32>):
|
^bb2(%3: memref<2xf32>, %4: memref<2xf32>):
|
||||||
br ^bb3(%3, %4 : memref<2xf32>, memref<2xf32>)
|
cf.br ^bb3(%3, %4 : memref<2xf32>, memref<2xf32>)
|
||||||
^bb3(%5: memref<2xf32>, %6: memref<2xf32>):
|
^bb3(%5: memref<2xf32>, %6: memref<2xf32>):
|
||||||
%7 = memref.alloca() : memref<2xf32>
|
%7 = memref.alloca() : memref<2xf32>
|
||||||
test.buffer_based in(%5: memref<2xf32>) out(%7: memref<2xf32>)
|
test.buffer_based in(%5: memref<2xf32>) out(%7: memref<2xf32>)
|
||||||
|
@ -788,17 +788,17 @@ func @ifElseNestedAlloca(
|
||||||
%arg2: memref<2xf32>) {
|
%arg2: memref<2xf32>) {
|
||||||
%0 = memref.alloca() : memref<2xf32>
|
%0 = memref.alloca() : memref<2xf32>
|
||||||
test.buffer_based in(%arg1: memref<2xf32>) out(%0: memref<2xf32>)
|
test.buffer_based in(%arg1: memref<2xf32>) out(%0: memref<2xf32>)
|
||||||
cond_br %arg0,
|
cf.cond_br %arg0,
|
||||||
^bb1(%arg1, %0 : memref<2xf32>, memref<2xf32>),
|
^bb1(%arg1, %0 : memref<2xf32>, memref<2xf32>),
|
||||||
^bb2(%0, %arg1 : memref<2xf32>, memref<2xf32>)
|
^bb2(%0, %arg1 : memref<2xf32>, memref<2xf32>)
|
||||||
^bb1(%1: memref<2xf32>, %2: memref<2xf32>):
|
^bb1(%1: memref<2xf32>, %2: memref<2xf32>):
|
||||||
br ^bb5(%1, %2 : memref<2xf32>, memref<2xf32>)
|
cf.br ^bb5(%1, %2 : memref<2xf32>, memref<2xf32>)
|
||||||
^bb2(%3: memref<2xf32>, %4: memref<2xf32>):
|
^bb2(%3: memref<2xf32>, %4: memref<2xf32>):
|
||||||
cond_br %arg0, ^bb3(%3 : memref<2xf32>), ^bb4(%4 : memref<2xf32>)
|
cf.cond_br %arg0, ^bb3(%3 : memref<2xf32>), ^bb4(%4 : memref<2xf32>)
|
||||||
^bb3(%5: memref<2xf32>):
|
^bb3(%5: memref<2xf32>):
|
||||||
br ^bb5(%5, %3 : memref<2xf32>, memref<2xf32>)
|
cf.br ^bb5(%5, %3 : memref<2xf32>, memref<2xf32>)
|
||||||
^bb4(%6: memref<2xf32>):
|
^bb4(%6: memref<2xf32>):
|
||||||
br ^bb5(%3, %6 : memref<2xf32>, memref<2xf32>)
|
cf.br ^bb5(%3, %6 : memref<2xf32>, memref<2xf32>)
|
||||||
^bb5(%7: memref<2xf32>, %8: memref<2xf32>):
|
^bb5(%7: memref<2xf32>, %8: memref<2xf32>):
|
||||||
%9 = memref.alloc() : memref<2xf32>
|
%9 = memref.alloc() : memref<2xf32>
|
||||||
test.buffer_based in(%7: memref<2xf32>) out(%9: memref<2xf32>)
|
test.buffer_based in(%7: memref<2xf32>) out(%9: memref<2xf32>)
|
||||||
|
@ -821,9 +821,9 @@ func @nestedRegionsAndCondBranchAlloca(
|
||||||
%arg0: i1,
|
%arg0: i1,
|
||||||
%arg1: memref<2xf32>,
|
%arg1: memref<2xf32>,
|
||||||
%arg2: memref<2xf32>) {
|
%arg2: memref<2xf32>) {
|
||||||
cond_br %arg0, ^bb1, ^bb2
|
cf.cond_br %arg0, ^bb1, ^bb2
|
||||||
^bb1:
|
^bb1:
|
||||||
br ^bb3(%arg1 : memref<2xf32>)
|
cf.br ^bb3(%arg1 : memref<2xf32>)
|
||||||
^bb2:
|
^bb2:
|
||||||
%0 = memref.alloc() : memref<2xf32>
|
%0 = memref.alloc() : memref<2xf32>
|
||||||
test.region_buffer_based in(%arg1: memref<2xf32>) out(%0: memref<2xf32>) {
|
test.region_buffer_based in(%arg1: memref<2xf32>) out(%0: memref<2xf32>) {
|
||||||
|
@ -833,13 +833,13 @@ func @nestedRegionsAndCondBranchAlloca(
|
||||||
%tmp1 = math.exp %gen1_arg0 : f32
|
%tmp1 = math.exp %gen1_arg0 : f32
|
||||||
test.region_yield %tmp1 : f32
|
test.region_yield %tmp1 : f32
|
||||||
}
|
}
|
||||||
br ^bb3(%0 : memref<2xf32>)
|
cf.br ^bb3(%0 : memref<2xf32>)
|
||||||
^bb3(%1: memref<2xf32>):
|
^bb3(%1: memref<2xf32>):
|
||||||
test.copy(%1, %arg2) : (memref<2xf32>, memref<2xf32>)
|
test.copy(%1, %arg2) : (memref<2xf32>, memref<2xf32>)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
// CHECK: (%[[cond:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}, %{{.*}}: {{.*}})
|
// CHECK: (%[[cond:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}, %{{.*}}: {{.*}})
|
||||||
// CHECK-NEXT: cond_br %[[cond]], ^[[BB1:.*]], ^[[BB2:.*]]
|
// CHECK-NEXT: cf.cond_br %[[cond]], ^[[BB1:.*]], ^[[BB2:.*]]
|
||||||
// CHECK: ^[[BB1]]:
|
// CHECK: ^[[BB1]]:
|
||||||
// CHECK: %[[ALLOC0:.*]] = bufferization.clone
|
// CHECK: %[[ALLOC0:.*]] = bufferization.clone
|
||||||
// CHECK: ^[[BB2]]:
|
// CHECK: ^[[BB2]]:
|
||||||
|
@ -1103,11 +1103,11 @@ func @loop_dynalloc(
|
||||||
%arg2: memref<?xf32>,
|
%arg2: memref<?xf32>,
|
||||||
%arg3: memref<?xf32>) {
|
%arg3: memref<?xf32>) {
|
||||||
%const0 = arith.constant 0 : i32
|
%const0 = arith.constant 0 : i32
|
||||||
br ^loopHeader(%const0, %arg2 : i32, memref<?xf32>)
|
cf.br ^loopHeader(%const0, %arg2 : i32, memref<?xf32>)
|
||||||
|
|
||||||
^loopHeader(%i : i32, %buff : memref<?xf32>):
|
^loopHeader(%i : i32, %buff : memref<?xf32>):
|
||||||
%lessThan = arith.cmpi slt, %i, %arg1 : i32
|
%lessThan = arith.cmpi slt, %i, %arg1 : i32
|
||||||
cond_br %lessThan,
|
cf.cond_br %lessThan,
|
||||||
^loopBody(%i, %buff : i32, memref<?xf32>),
|
^loopBody(%i, %buff : i32, memref<?xf32>),
|
||||||
^exit(%buff : memref<?xf32>)
|
^exit(%buff : memref<?xf32>)
|
||||||
|
|
||||||
|
@ -1116,7 +1116,7 @@ func @loop_dynalloc(
|
||||||
%inc = arith.addi %val, %const1 : i32
|
%inc = arith.addi %val, %const1 : i32
|
||||||
%size = arith.index_cast %inc : i32 to index
|
%size = arith.index_cast %inc : i32 to index
|
||||||
%alloc1 = memref.alloc(%size) : memref<?xf32>
|
%alloc1 = memref.alloc(%size) : memref<?xf32>
|
||||||
br ^loopHeader(%inc, %alloc1 : i32, memref<?xf32>)
|
cf.br ^loopHeader(%inc, %alloc1 : i32, memref<?xf32>)
|
||||||
|
|
||||||
^exit(%buff3 : memref<?xf32>):
|
^exit(%buff3 : memref<?xf32>):
|
||||||
test.copy(%buff3, %arg3) : (memref<?xf32>, memref<?xf32>)
|
test.copy(%buff3, %arg3) : (memref<?xf32>, memref<?xf32>)
|
||||||
|
@ -1136,17 +1136,17 @@ func @do_loop_alloc(
|
||||||
%arg2: memref<2xf32>,
|
%arg2: memref<2xf32>,
|
||||||
%arg3: memref<2xf32>) {
|
%arg3: memref<2xf32>) {
|
||||||
%const0 = arith.constant 0 : i32
|
%const0 = arith.constant 0 : i32
|
||||||
br ^loopBody(%const0, %arg2 : i32, memref<2xf32>)
|
cf.br ^loopBody(%const0, %arg2 : i32, memref<2xf32>)
|
||||||
|
|
||||||
^loopBody(%val : i32, %buff2: memref<2xf32>):
|
^loopBody(%val : i32, %buff2: memref<2xf32>):
|
||||||
%const1 = arith.constant 1 : i32
|
%const1 = arith.constant 1 : i32
|
||||||
%inc = arith.addi %val, %const1 : i32
|
%inc = arith.addi %val, %const1 : i32
|
||||||
%alloc1 = memref.alloc() : memref<2xf32>
|
%alloc1 = memref.alloc() : memref<2xf32>
|
||||||
br ^loopHeader(%inc, %alloc1 : i32, memref<2xf32>)
|
cf.br ^loopHeader(%inc, %alloc1 : i32, memref<2xf32>)
|
||||||
|
|
||||||
^loopHeader(%i : i32, %buff : memref<2xf32>):
|
^loopHeader(%i : i32, %buff : memref<2xf32>):
|
||||||
%lessThan = arith.cmpi slt, %i, %arg1 : i32
|
%lessThan = arith.cmpi slt, %i, %arg1 : i32
|
||||||
cond_br %lessThan,
|
cf.cond_br %lessThan,
|
||||||
^loopBody(%i, %buff : i32, memref<2xf32>),
|
^loopBody(%i, %buff : i32, memref<2xf32>),
|
||||||
^exit(%buff : memref<2xf32>)
|
^exit(%buff : memref<2xf32>)
|
||||||
|
|
||||||
|
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue