Rewrite MLPatternLoweringPass to no longer inherit from FunctionPass and just provide a utility function that applies ML patterns.

PiperOrigin-RevId: 235194034
This commit is contained in:
River Riddle 2019-02-22 08:10:10 -08:00 committed by jpienaar
parent 8564b274db
commit 5410dff790
2 changed files with 28 additions and 49 deletions

View File

@ -81,41 +81,7 @@ namespace detail {
/// Owning list of ML lowering patterns.
using OwningMLLoweringPatternList =
std::vector<std::unique_ptr<mlir::MLLoweringPattern>>;
} // namespace detail
/// Generic lowering pass for ML functions. The lowering details are defined as
/// a sequence of pattern matchers. The following constraints on matchers
/// apply:
/// - only one (match root) operation can be removed;
/// - the code produced by rewriters is final, it is not pattern-matched;
/// - the matchers are applied in their order of appearance in the list;
/// - if the match is found, the operation is rewritten immediately and the
/// next _original_ operation is considered.
/// In other words, for each operation, the pass applies the first matching
/// rewriter in the list and advances to the (lexically) next operation.
/// Non-operation instructions (ForInst) are ignored.
/// This is similar to greedy worklist-based pattern rewriter, except that this
/// operates on ML functions using an ML builder and does not maintain the work
/// list. Note that, as of the time of writing, worklist-based rewriter did not
/// support removing multiple operations either.
template <typename... Patterns>
class MLPatternLoweringPass : public FunctionPass {
public:
explicit MLPatternLoweringPass(const PassID *ID) : FunctionPass(ID) {}
virtual std::unique_ptr<MLFuncGlobalLoweringState>
makeFuncWiseState(Function *f) const {
return nullptr;
}
PassResult runOnFunction(Function *f) override;
};
/////////////////////////////////////////////////////////////////////
// MLPatternLoweringPass template implementations
/////////////////////////////////////////////////////////////////////
namespace detail {
template <typename Pattern, typename... Patterns> struct ListAdder {
static void addPatternsToList(OwningMLLoweringPatternList *list,
MLIRContext *context) {
@ -134,11 +100,25 @@ template <typename Pattern> struct ListAdder<Pattern> {
};
} // namespace detail
/// Generic lowering for ML patterns. The lowering details are defined as
/// a sequence of pattern matchers. The following constraints on matchers
/// apply:
/// - only one (match root) operation can be removed;
/// - the code produced by rewriters is final, it is not pattern-matched;
/// - the matchers are applied in their order of appearance in the list;
/// - if the match is found, the operation is rewritten immediately and the
/// next _original_ operation is considered.
/// In other words, for each operation, apply the first matching rewriter in the
/// list and advance to the (lexically) next operation. This is similar to
/// greedy worklist-based pattern rewriter, except that this operates on ML
/// functions using an ML builder and does not maintain the work list. Note
/// that, as of the time of writing, worklist-based rewriter did not support
/// removing multiple operations either.
template <typename... Patterns>
PassResult MLPatternLoweringPass<Patterns...>::runOnFunction(Function *f) {
void applyMLPatternsGreedily(
Function *f, MLFuncGlobalLoweringState *funcWiseState = nullptr) {
detail::OwningMLLoweringPatternList patterns;
detail::ListAdder<Patterns...>::addPatternsToList(&patterns, f->getContext());
auto funcWiseState = makeFuncWiseState(f);
FuncBuilder builder(f);
MLFuncLoweringRewriter rewriter(&builder);
@ -148,19 +128,15 @@ PassResult MLPatternLoweringPass<Patterns...>::runOnFunction(Function *f) {
for (Instruction *inst : ops) {
for (const auto &pattern : patterns) {
rewriter.getBuilder()->setInsertionPoint(inst);
auto matchResult = pattern->match(inst);
if (matchResult) {
pattern->rewriteOpInst(inst, funcWiseState.get(),
std::move(*matchResult), &rewriter);
builder.setInsertionPoint(inst);
if (auto matchResult = pattern->match(inst)) {
pattern->rewriteOpInst(inst, funcWiseState, std::move(*matchResult),
&rewriter);
break;
}
}
}
return PassResult::Success;
}
} // end namespace mlir
#endif // MLIR_TRANSFORMS_MLPATTERNLOWERINGPASS_H

View File

@ -424,12 +424,15 @@ public:
}
};
struct LowerVectorTransfersPass
: public MLPatternLoweringPass<
VectorTransferExpander<VectorTransferReadOp>,
VectorTransferExpander<VectorTransferWriteOp>> {
struct LowerVectorTransfersPass : public FunctionPass {
LowerVectorTransfersPass()
: MLPatternLoweringPass(&LowerVectorTransfersPass::passID) {}
: FunctionPass(&LowerVectorTransfersPass::passID) {}
PassResult runOnFunction(Function *fn) override {
applyMLPatternsGreedily<VectorTransferExpander<VectorTransferReadOp>,
VectorTransferExpander<VectorTransferWriteOp>>(fn);
return success();
}
// Thread-safe RAII context with local scope. BumpPtrAllocator freed on exit.
edsc::ScopedEDSCContext raiiContext;