forked from OSchip/llvm-project
Rewrite MLPatternLoweringPass to no longer inherit from FunctionPass and just provide a utility function that applies ML patterns.
PiperOrigin-RevId: 235194034
This commit is contained in:
parent
8564b274db
commit
5410dff790
|
@ -81,41 +81,7 @@ namespace detail {
|
|||
/// Owning list of ML lowering patterns.
|
||||
using OwningMLLoweringPatternList =
|
||||
std::vector<std::unique_ptr<mlir::MLLoweringPattern>>;
|
||||
} // namespace detail
|
||||
|
||||
/// Generic lowering pass for ML functions. The lowering details are defined as
|
||||
/// a sequence of pattern matchers. The following constraints on matchers
|
||||
/// apply:
|
||||
/// - only one (match root) operation can be removed;
|
||||
/// - the code produced by rewriters is final, it is not pattern-matched;
|
||||
/// - the matchers are applied in their order of appearance in the list;
|
||||
/// - if the match is found, the operation is rewritten immediately and the
|
||||
/// next _original_ operation is considered.
|
||||
/// In other words, for each operation, the pass applies the first matching
|
||||
/// rewriter in the list and advances to the (lexically) next operation.
|
||||
/// Non-operation instructions (ForInst) are ignored.
|
||||
/// This is similar to greedy worklist-based pattern rewriter, except that this
|
||||
/// operates on ML functions using an ML builder and does not maintain the work
|
||||
/// list. Note that, as of the time of writing, worklist-based rewriter did not
|
||||
/// support removing multiple operations either.
|
||||
template <typename... Patterns>
|
||||
class MLPatternLoweringPass : public FunctionPass {
|
||||
public:
|
||||
explicit MLPatternLoweringPass(const PassID *ID) : FunctionPass(ID) {}
|
||||
|
||||
virtual std::unique_ptr<MLFuncGlobalLoweringState>
|
||||
makeFuncWiseState(Function *f) const {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
PassResult runOnFunction(Function *f) override;
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////
|
||||
// MLPatternLoweringPass template implementations
|
||||
/////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace detail {
|
||||
template <typename Pattern, typename... Patterns> struct ListAdder {
|
||||
static void addPatternsToList(OwningMLLoweringPatternList *list,
|
||||
MLIRContext *context) {
|
||||
|
@ -134,11 +100,25 @@ template <typename Pattern> struct ListAdder<Pattern> {
|
|||
};
|
||||
} // namespace detail
|
||||
|
||||
/// Generic lowering for ML patterns. The lowering details are defined as
|
||||
/// a sequence of pattern matchers. The following constraints on matchers
|
||||
/// apply:
|
||||
/// - only one (match root) operation can be removed;
|
||||
/// - the code produced by rewriters is final, it is not pattern-matched;
|
||||
/// - the matchers are applied in their order of appearance in the list;
|
||||
/// - if the match is found, the operation is rewritten immediately and the
|
||||
/// next _original_ operation is considered.
|
||||
/// In other words, for each operation, apply the first matching rewriter in the
|
||||
/// list and advance to the (lexically) next operation. This is similar to
|
||||
/// greedy worklist-based pattern rewriter, except that this operates on ML
|
||||
/// functions using an ML builder and does not maintain the work list. Note
|
||||
/// that, as of the time of writing, worklist-based rewriter did not support
|
||||
/// removing multiple operations either.
|
||||
template <typename... Patterns>
|
||||
PassResult MLPatternLoweringPass<Patterns...>::runOnFunction(Function *f) {
|
||||
void applyMLPatternsGreedily(
|
||||
Function *f, MLFuncGlobalLoweringState *funcWiseState = nullptr) {
|
||||
detail::OwningMLLoweringPatternList patterns;
|
||||
detail::ListAdder<Patterns...>::addPatternsToList(&patterns, f->getContext());
|
||||
auto funcWiseState = makeFuncWiseState(f);
|
||||
|
||||
FuncBuilder builder(f);
|
||||
MLFuncLoweringRewriter rewriter(&builder);
|
||||
|
@ -148,19 +128,15 @@ PassResult MLPatternLoweringPass<Patterns...>::runOnFunction(Function *f) {
|
|||
|
||||
for (Instruction *inst : ops) {
|
||||
for (const auto &pattern : patterns) {
|
||||
rewriter.getBuilder()->setInsertionPoint(inst);
|
||||
auto matchResult = pattern->match(inst);
|
||||
if (matchResult) {
|
||||
pattern->rewriteOpInst(inst, funcWiseState.get(),
|
||||
std::move(*matchResult), &rewriter);
|
||||
builder.setInsertionPoint(inst);
|
||||
if (auto matchResult = pattern->match(inst)) {
|
||||
pattern->rewriteOpInst(inst, funcWiseState, std::move(*matchResult),
|
||||
&rewriter);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return PassResult::Success;
|
||||
}
|
||||
|
||||
} // end namespace mlir
|
||||
|
||||
#endif // MLIR_TRANSFORMS_MLPATTERNLOWERINGPASS_H
|
||||
|
|
|
@ -424,12 +424,15 @@ public:
|
|||
}
|
||||
};
|
||||
|
||||
struct LowerVectorTransfersPass
|
||||
: public MLPatternLoweringPass<
|
||||
VectorTransferExpander<VectorTransferReadOp>,
|
||||
VectorTransferExpander<VectorTransferWriteOp>> {
|
||||
struct LowerVectorTransfersPass : public FunctionPass {
|
||||
LowerVectorTransfersPass()
|
||||
: MLPatternLoweringPass(&LowerVectorTransfersPass::passID) {}
|
||||
: FunctionPass(&LowerVectorTransfersPass::passID) {}
|
||||
|
||||
PassResult runOnFunction(Function *fn) override {
|
||||
applyMLPatternsGreedily<VectorTransferExpander<VectorTransferReadOp>,
|
||||
VectorTransferExpander<VectorTransferWriteOp>>(fn);
|
||||
return success();
|
||||
}
|
||||
|
||||
// Thread-safe RAII context with local scope. BumpPtrAllocator freed on exit.
|
||||
edsc::ScopedEDSCContext raiiContext;
|
||||
|
|
Loading…
Reference in New Issue