Remove MLPatternLoweringPass and rewrite LowerVectorTransfers to use RewritePattern instead.

--

PiperOrigin-RevId: 241455472
River Riddle 2019-04-01 20:43:13 -07:00 committed by Mehdi Amini
parent bae95d25e5
commit 084669e005
3 changed files with 44 additions and 199 deletions
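For orientation, the RewritePattern-based structure this change migrates to looks roughly like the sketch below. It only illustrates the idiom; `MyOp`, `MyOpLowering`, and `MyLoweringPass` are hypothetical names, and the real patterns are the VectorTransferRewriter specializations in the LowerVectorTransfers changes further down.

#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/Passes.h"

using namespace mlir;

// Hypothetical pattern lowering a hypothetical `MyOp`; sketch only.
struct MyOpLowering : public RewritePattern {
  explicit MyOpLowering(MLIRContext *context)
      : RewritePattern(MyOp::getOperationName(), /*benefit=*/1, context) {}

  PatternMatchResult matchAndRewrite(Operation *op,
                                     PatternRewriter &rewriter) const override {
    // op->cast<MyOp>() would give access to op-specific accessors here.
    // Build the replacement IR through `rewriter`, then retire `op`.
    rewriter.replaceOp(op, /*newValues=*/llvm::None);
    return matchSuccess();
  }
};

// Hypothetical pass wiring the pattern into the greedy rewrite driver.
struct MyLoweringPass : public FunctionPass<MyLoweringPass> {
  void runOnFunction() {
    OwningRewritePatternList patterns;
    patterns.push_back(llvm::make_unique<MyOpLowering>(&getContext()));
    applyPatternsGreedily(getFunction(), std::move(patterns));
  }
};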


@@ -1,142 +0,0 @@
//===- MLPatternLoweringPass.h - Generic ML lowering pass -------*- C++ -*-===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// Defines a generic class to implement lowering passes on ML functions as a
// list of pattern rewriters.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_TRANSFORMS_MLPATTERNLOWERINGPASS_H
#define MLIR_TRANSFORMS_MLPATTERNLOWERINGPASS_H
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include <type_traits>
namespace mlir {
/// Specialization of the pattern rewriter to ML functions.
class MLFuncLoweringRewriter : public PatternRewriter {
public:
explicit MLFuncLoweringRewriter(FuncBuilder *builder)
: PatternRewriter(builder->getContext()), builder(builder) {}
FuncBuilder *getBuilder() { return builder; }
Operation *createOperation(const OperationState &state) override {
auto *result = builder->createOperation(state);
return result;
}
private:
FuncBuilder *builder;
};
/// Base class for the Function-wise lowering state. A pointer to the same
/// instance of the subclass will be passed to all `rewrite` calls on operations
/// that belong to the same Function.
class MLFuncGlobalLoweringState {
public:
virtual ~MLFuncGlobalLoweringState() {}
protected:
// Must be subclassed.
MLFuncGlobalLoweringState() {}
};
/// Base class for Function lowering patterns.
class MLLoweringPattern : public Pattern {
public:
/// Subclasses must override this function to implement rewriting. It will be
/// called on all operations found by `match` (declared in Pattern, subclasses
/// must override). It will be passed the function-wise state, common to all
/// matches, and the state returned by the `match` call, if any. The subclass
/// must use `rewriter` to modify the function.
virtual void rewriteOpInst(Operation *op,
MLFuncGlobalLoweringState *funcWiseState,
std::unique_ptr<PatternState> opState,
MLFuncLoweringRewriter *rewriter) const = 0;
protected:
// Must be subclassed.
MLLoweringPattern(StringRef opName, int64_t benefit, MLIRContext *context)
: Pattern(opName, benefit, context) {}
};
namespace detail {
/// Owning list of ML lowering patterns.
using OwningMLLoweringPatternList =
std::vector<std::unique_ptr<mlir::MLLoweringPattern>>;
template <typename Pattern, typename... Patterns> struct ListAdder {
static void addPatternsToList(OwningMLLoweringPatternList *list,
MLIRContext *context) {
static_assert(std::is_base_of<MLLoweringPattern, Pattern>::value,
"can only add subclasses of MLLoweringPattern");
list->emplace_back(new Pattern(context));
ListAdder<Patterns...>::addPatternsToList(list, context);
}
};
template <typename Pattern> struct ListAdder<Pattern> {
static void addPatternsToList(OwningMLLoweringPatternList *list,
MLIRContext *context) {
list->emplace_back(new Pattern(context));
}
};
} // namespace detail
/// Generic lowering for ML patterns. The lowering details are defined as
/// a sequence of pattern matchers. The following constraints on matchers
/// apply:
/// - only one (match root) operation can be removed;
/// - the code produced by rewriters is final, it is not pattern-matched;
/// - the matchers are applied in their order of appearance in the list;
/// - if the match is found, the operation is rewritten immediately and the
/// next _original_ operation is considered.
/// In other words, for each operation, apply the first matching rewriter in the
/// list and advance to the (lexically) next operation. This is similar to a
/// greedy worklist-based pattern rewriter, except that this one operates on ML
/// functions using an ML builder and does not maintain a worklist. Note that,
/// as of this writing, the worklist-based rewriter did not support removing
/// multiple operations either.
template <typename... Patterns>
void applyMLPatternsGreedily(
Function *f, MLFuncGlobalLoweringState *funcWiseState = nullptr) {
detail::OwningMLLoweringPatternList patterns;
detail::ListAdder<Patterns...>::addPatternsToList(&patterns, f->getContext());
FuncBuilder builder(f);
MLFuncLoweringRewriter rewriter(&builder);
llvm::SmallVector<Operation *, 16> ops;
f->walk([&ops](Operation *op) { ops.push_back(op); });
for (Operation *op : ops) {
for (const auto &pattern : patterns) {
builder.setInsertionPoint(op);
if (auto matchResult = pattern->match(op)) {
pattern->rewriteOpInst(op, funcWiseState, std::move(*matchResult),
&rewriter);
break;
}
}
}
}
} // end namespace mlir
#endif // MLIR_TRANSFORMS_MLPATTERNLOWERINGPASS_H
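For reference, the deleted API was meant to be used along these lines. This is a sketch only, modeled on the VectorTransferExpander removed in the LowerVectorTransfers changes below; `SomeOp`, `SomeOpLowering`, and `lowerSomeOps` are hypothetical names.

#include "mlir/Transforms/MLPatternLoweringPass.h"

using namespace mlir;

// Hypothetical lowering pattern for a hypothetical `SomeOp`; sketch only.
struct SomeOpLowering : public MLLoweringPattern {
  explicit SomeOpLowering(MLIRContext *context)
      : MLLoweringPattern(SomeOp::getOperationName(), /*benefit=*/1, context) {}

  PatternMatchResult match(Operation *op) const override {
    if (m_Op<SomeOp>().match(op))
      return matchSuccess();
    return matchFailure();
  }

  void rewriteOpInst(Operation *op, MLFuncGlobalLoweringState *funcWiseState,
                     std::unique_ptr<PatternState> opState,
                     MLFuncLoweringRewriter *rewriter) const override {
    // Emit the replacement IR through rewriter->getBuilder(), then erase `op`.
  }
};

// Function-wide application: for each operation, the first matching pattern
// in the template argument list fires once, without a worklist.
void lowerSomeOps(Function *f) {
  applyMLPatternsGreedily<SomeOpLowering>(f);
}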


@@ -39,7 +39,6 @@
#include "mlir/Pass/Pass.h"
#include "mlir/StandardOps/Ops.h"
#include "mlir/Support/Functional.h"
#include "mlir/Transforms/MLPatternLoweringPass.h"
#include "mlir/Transforms/Passes.h"
#include "mlir/VectorOps/VectorOps.h"
@@ -99,30 +98,27 @@ namespace {
/// 4. local memory deallocation.
/// Minor variations occur depending on whether a VectorTransferReadOp or
/// a VectorTransferWriteOp is rewritten.
template <typename VectorTransferOpTy> class VectorTransferRewriter {
public:
VectorTransferRewriter(VectorTransferOpTy transfer,
MLFuncLoweringRewriter *rewriter,
MLFuncGlobalLoweringState *state);
template <typename VectorTransferOpTy>
struct VectorTransferRewriter : public RewritePattern {
explicit VectorTransferRewriter(MLIRContext *context)
: RewritePattern(VectorTransferOpTy::getOperationName(), 1, context) {}
/// Used for staging the transfer in a local scalar buffer.
MemRefType tmpMemRefType() {
MemRefType tmpMemRefType(VectorTransferOpTy transfer) const {
auto vectorType = transfer.getVectorType();
return MemRefType::get(vectorType.getShape(), vectorType.getElementType(),
{}, 0);
}
/// View of tmpMemRefType as one vector, used in vector load/store to tmp
/// buffer.
MemRefType vectorMemRefType() {
MemRefType vectorMemRefType(VectorTransferOpTy transfer) const {
return MemRefType::get({1}, transfer.getVectorType(), {}, 0);
}
/// Performs the rewrite.
void rewrite();
private:
VectorTransferOpTy transfer;
MLFuncLoweringRewriter *rewriter;
MLFuncGlobalLoweringState *state;
/// Performs the rewrite.
PatternMatchResult matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const override;
};
} // end anonymous namespace
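As a concrete instance of the two helpers above: for a transfer whose vector type is vector<5x4x3xf32> (the shape exercised in the test update at the end of this commit), tmpMemRefType(transfer) yields memref<5x4x3xf32> and vectorMemRefType(transfer) yields memref<1xvector<5x4x3xf32>>, i.e. the same scalar staging buffer reinterpreted as a single-element vector memref for the vector load/store.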
@@ -213,12 +209,6 @@ clip(VectorTransferOpTy transfer, edsc::MemRefView &view,
return clippedScalarAccessExprs;
}
template <typename VectorTransferOpTy>
VectorTransferRewriter<VectorTransferOpTy>::VectorTransferRewriter(
VectorTransferOpTy transfer, MLFuncLoweringRewriter *rewriter,
MLFuncGlobalLoweringState *state)
: transfer(transfer), rewriter(rewriter), state(state){};
/// Lowers VectorTransferReadOp into a combination of:
/// 1. local memory allocation;
/// 2. perfect loop nest over:
@@ -260,13 +250,20 @@ VectorTransferRewriter<VectorTransferOpTy>::VectorTransferRewriter(
///
/// TODO(ntv): implement alternatives to clipping.
/// TODO(ntv): support non-data-parallel operations.
template <> void VectorTransferRewriter<VectorTransferReadOp>::rewrite() {
/// Performs the rewrite.
template <>
PatternMatchResult
VectorTransferRewriter<VectorTransferReadOp>::matchAndRewrite(
Operation *op, PatternRewriter &rewriter) const {
using namespace mlir::edsc;
using namespace mlir::edsc::op;
using namespace mlir::edsc::intrinsics;
VectorTransferReadOp transfer = op->cast<VectorTransferReadOp>();
// 1. Setup all the captures.
ScopedContext scope(FuncBuilder(transfer.getOperation()), transfer.getLoc());
ScopedContext scope(FuncBuilder(op), transfer.getLoc());
IndexedValue remote(transfer.getMemRef());
MemRefView view(transfer.getMemRef());
VectorView vectorView(transfer.getVector());
@@ -281,9 +278,9 @@ template <> void VectorTransferRewriter<VectorTransferReadOp>::rewrite() {
auto steps = vectorView.getSteps();
// 2. Emit alloc-copy-load-dealloc.
ValueHandle tmp = alloc(tmpMemRefType());
ValueHandle tmp = alloc(tmpMemRefType(transfer));
IndexedValue local(tmp);
ValueHandle vec = vector_type_cast(tmp, vectorMemRefType());
ValueHandle vec = vector_type_cast(tmp, vectorMemRefType(transfer));
LoopNestBuilder(pivs, lbs, ubs, steps)({
// Computes clippedScalarAccessExprs in the loop nest scope (ivs exist).
local(ivs) = remote(clip(transfer, view, ivs)),
@@ -292,8 +289,8 @@ template <> void VectorTransferRewriter<VectorTransferReadOp>::rewrite() {
(dealloc(tmp)); // vexing parse
// 3. Propagate.
transfer.replaceAllUsesWith(vectorValue.getValue());
transfer.erase();
rewriter.replaceOp(op, vectorValue.getValue());
return matchSuccess();
}
/// Lowers VectorTransferWriteOp into a combination of:
@@ -314,13 +311,18 @@ template <> void VectorTransferRewriter<VectorTransferReadOp>::rewrite() {
///
/// TODO(ntv): implement alternatives to clipping.
/// TODO(ntv): support non-data-parallel operations.
template <> void VectorTransferRewriter<VectorTransferWriteOp>::rewrite() {
template <>
PatternMatchResult
VectorTransferRewriter<VectorTransferWriteOp>::matchAndRewrite(
Operation *op, PatternRewriter &rewriter) const {
using namespace mlir::edsc;
using namespace mlir::edsc::op;
using namespace mlir::edsc::intrinsics;
VectorTransferWriteOp transfer = op->cast<VectorTransferWriteOp>();
// 1. Setup all the captures.
ScopedContext scope(FuncBuilder(transfer.getOperation()), transfer.getLoc());
ScopedContext scope(FuncBuilder(op), transfer.getLoc());
IndexedValue remote(transfer.getMemRef());
MemRefView view(transfer.getMemRef());
ValueHandle vectorValue(transfer.getVector());
@@ -336,9 +338,9 @@ template <> void VectorTransferRewriter<VectorTransferWriteOp>::rewrite() {
auto steps = vectorView.getSteps();
// 2. Emit alloc-store-copy-dealloc.
ValueHandle tmp = alloc(tmpMemRefType());
ValueHandle tmp = alloc(tmpMemRefType(transfer));
IndexedValue local(tmp);
ValueHandle vec = vector_type_cast(tmp, vectorMemRefType());
ValueHandle vec = vector_type_cast(tmp, vectorMemRefType(transfer));
store(vectorValue, vec, {constant_index(0)});
LoopNestBuilder(pivs, lbs, ubs, steps)({
// Computes clippedScalarAccessExprs in the loop nest scope (ivs exist).
@@ -346,36 +348,23 @@ template <> void VectorTransferRewriter<VectorTransferWriteOp>::rewrite() {
});
(dealloc(tmp)); // vexing parse...
transfer.erase();
rewriter.replaceOp(op, llvm::None);
return matchSuccess();
}
namespace {
template <typename VectorTransferOpTy>
class VectorTransferExpander : public MLLoweringPattern {
public:
explicit VectorTransferExpander(MLIRContext *context)
: MLLoweringPattern(VectorTransferOpTy::getOperationName(), 1, context) {}
PatternMatchResult match(Operation *op) const override {
if (m_Op<VectorTransferOpTy>().match(op))
return matchSuccess();
return matchFailure();
}
void rewriteOpInst(Operation *op, MLFuncGlobalLoweringState *funcWiseState,
std::unique_ptr<PatternState> opState,
MLFuncLoweringRewriter *rewriter) const override {
VectorTransferRewriter<VectorTransferOpTy>(
op->dyn_cast<VectorTransferOpTy>(), rewriter, funcWiseState)
.rewrite();
}
};
struct LowerVectorTransfersPass
: public FunctionPass<LowerVectorTransfersPass> {
void runOnFunction() {
auto &f = getFunction();
applyMLPatternsGreedily<VectorTransferExpander<VectorTransferReadOp>,
VectorTransferExpander<VectorTransferWriteOp>>(&f);
OwningRewritePatternList patterns;
auto *context = &getContext();
patterns.push_back(
llvm::make_unique<VectorTransferRewriter<VectorTransferReadOp>>(
context));
patterns.push_back(
llvm::make_unique<VectorTransferRewriter<VectorTransferWriteOp>>(
context));
applyPatternsGreedily(getFunction(), std::move(patterns));
}
};
@@ -388,5 +377,3 @@ FunctionPassBase *mlir::createLowerVectorTransfersPass() {
static PassRegistration<LowerVectorTransfersPass>
pass("lower-vector-transfers", "Materializes vector transfer ops to a "
"proper abstraction for the hardware");
#undef DEBUG_TYPE
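With this registration, the lowering should be reachable from the standard mlir-opt driver via the -lower-vector-transfers flag (assuming the usual pass registration flow; the driver itself is not part of this diff).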


@@ -131,8 +131,8 @@ func @materialize_read(%M: index, %N: index, %O: index, %P: index) {
// CHECK-LABEL:func @materialize_write(%arg0: index, %arg1: index, %arg2: index, %arg3: index) {
func @materialize_write(%M: index, %N: index, %O: index, %P: index) {
// CHECK-NEXT: %0 = alloc(%arg0, %arg1, %arg2, %arg3) : memref<?x?x?x?xf32>
// CHECK-NEXT: %cst = constant splat<vector<5x4x3xf32>, 1.000000e+00> : vector<5x4x3xf32>
// CHECK-NEXT: %0 = alloc(%arg0, %arg1, %arg2, %arg3) : memref<?x?x?x?xf32>
// CHECK-NEXT: affine.for %[[I0:.*]] = 0 to %arg0 step 3 {
// CHECK-NEXT: affine.for %[[I1:.*]] = 0 to %arg1 step 4 {
// CHECK-NEXT: affine.for %[[I2:.*]] = 0 to %arg2 {