Make "LowerToCFG" an operation pass

The conversion from the Loops dialect to the Standard dialect, also known as
loop-to-cfg lowering, has been historically a function pass. It can be required
on non-Standard function Ops, in particular the recently introduced GPU
functions. Make the conversion an operation pass instead of a function pass.

PiperOrigin-RevId: 285814560
This commit is contained in:
Alex Zinenko 2019-12-16 11:35:29 -08:00 committed by A. Unique TensorFlower
parent 3ae56c4135
commit ed749b7689
2 changed files with 7 additions and 8 deletions

View File

@ -22,10 +22,9 @@
#include <vector>
namespace mlir {
class FuncOp;
struct LogicalResult;
class MLIRContext;
template <typename T> class OpPassBase;
class Pass;
class RewritePattern;
// Owning list of rewriting patterns.
@ -38,7 +37,7 @@ void populateLoopToStdConversionPatterns(OwningRewritePatternList &patterns,
MLIRContext *ctx);
/// Creates a pass to convert loop.for, loop.if and loop.terminator ops to CFG.
std::unique_ptr<OpPassBase<FuncOp>> createLowerToCFGPass();
std::unique_ptr<Pass> createLowerToCFGPass();
} // namespace mlir

View File

@ -38,8 +38,8 @@ using namespace mlir::loop;
namespace {
struct LoopToStandardPass : public FunctionPass<LoopToStandardPass> {
void runOnFunction() override;
struct LoopToStandardPass : public OperationPass<LoopToStandardPass> {
void runOnOperation() override;
};
// Create a CFG subgraph for the loop around its body blocks (if the body
@ -261,16 +261,16 @@ void mlir::populateLoopToStdConversionPatterns(
patterns.insert<ForLowering, IfLowering, TerminatorLowering>(ctx);
}
void LoopToStandardPass::runOnFunction() {
void LoopToStandardPass::runOnOperation() {
OwningRewritePatternList patterns;
populateLoopToStdConversionPatterns(patterns, &getContext());
ConversionTarget target(getContext());
target.addLegalDialect<StandardOpsDialect>();
if (failed(applyPartialConversion(getFunction(), target, patterns)))
if (failed(applyPartialConversion(getOperation(), target, patterns)))
signalPassFailure();
}
std::unique_ptr<OpPassBase<FuncOp>> mlir::createLowerToCFGPass() {
std::unique_ptr<Pass> mlir::createLowerToCFGPass() {
return std::make_unique<LoopToStandardPass>();
}