forked from OSchip/llvm-project
Remove MLPatternLoweringPass and rewrite LowerVectorTransfers to use RewritePattern instead.
-- PiperOrigin-RevId: 241455472
This commit is contained in:
parent
bae95d25e5
commit
084669e005
|
@ -1,142 +0,0 @@
|
|||
//===- MLPatternLoweringPass.h - Generic ML lowering pass -------*- C++ -*-===//
|
||||
//
|
||||
// Copyright 2019 The MLIR Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
// =============================================================================
|
||||
//
|
||||
// Defines a generic class to implement lowering passes on ML functions as a
|
||||
// list of pattern rewriters.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef MLIR_TRANSFORMS_MLPATTERNLOWERINGPASS_H
|
||||
#define MLIR_TRANSFORMS_MLPATTERNLOWERINGPASS_H
|
||||
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include <type_traits>
|
||||
|
||||
namespace mlir {
|
||||
|
||||
/// Specialization of the pattern rewriter to ML functions.
|
||||
class MLFuncLoweringRewriter : public PatternRewriter {
|
||||
public:
|
||||
explicit MLFuncLoweringRewriter(FuncBuilder *builder)
|
||||
: PatternRewriter(builder->getContext()), builder(builder) {}
|
||||
|
||||
FuncBuilder *getBuilder() { return builder; }
|
||||
|
||||
Operation *createOperation(const OperationState &state) override {
|
||||
auto *result = builder->createOperation(state);
|
||||
return result;
|
||||
}
|
||||
|
||||
private:
|
||||
FuncBuilder *builder;
|
||||
};
|
||||
|
||||
/// Root of the per-Function lowering state hierarchy. One instance of a
/// subclass is shared by pointer across every `rewrite` call made for
/// operations belonging to the same Function.
class MLFuncGlobalLoweringState {
public:
  /// Virtual so that subclasses can be destroyed through a base pointer.
  virtual ~MLFuncGlobalLoweringState() = default;

protected:
  /// Constructible by subclasses only; this type is meant to be derived from.
  MLFuncGlobalLoweringState() = default;
};
|
||||
|
||||
/// Base class for Function lowering patterns.
|
||||
class MLLoweringPattern : public Pattern {
|
||||
public:
|
||||
/// Subclasses must override this function to implement rewriting. It will be
|
||||
/// called on all operations found by `match` (declared in Pattern, subclasses
|
||||
/// must override). It will be passed the function-wise state, common to all
|
||||
/// matches, and the state returned by the `match` call, if any. The subclass
|
||||
/// must use `rewriter` to modify the function.
|
||||
virtual void rewriteOpInst(Operation *op,
|
||||
MLFuncGlobalLoweringState *funcWiseState,
|
||||
std::unique_ptr<PatternState> opState,
|
||||
MLFuncLoweringRewriter *rewriter) const = 0;
|
||||
|
||||
protected:
|
||||
// Must be subclassed.
|
||||
MLLoweringPattern(StringRef opName, int64_t benefit, MLIRContext *context)
|
||||
: Pattern(opName, benefit, context) {}
|
||||
};
|
||||
|
||||
namespace detail {
|
||||
/// Owning list of ML lowering patterns.
|
||||
using OwningMLLoweringPatternList =
|
||||
std::vector<std::unique_ptr<mlir::MLLoweringPattern>>;
|
||||
|
||||
template <typename Pattern, typename... Patterns> struct ListAdder {
|
||||
static void addPatternsToList(OwningMLLoweringPatternList *list,
|
||||
MLIRContext *context) {
|
||||
static_assert(std::is_base_of<MLLoweringPattern, Pattern>::value,
|
||||
"can only add subclasses of MLLoweringPattern");
|
||||
list->emplace_back(new Pattern(context));
|
||||
ListAdder<Patterns...>::addPatternsToList(list, context);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename Pattern> struct ListAdder<Pattern> {
|
||||
static void addPatternsToList(OwningMLLoweringPatternList *list,
|
||||
MLIRContext *context) {
|
||||
list->emplace_back(new Pattern(context));
|
||||
}
|
||||
};
|
||||
} // namespace detail
|
||||
|
||||
/// Generic lowering for ML patterns. The lowering details are defined as
|
||||
/// a sequence of pattern matchers. The following constraints on matchers
|
||||
/// apply:
|
||||
/// - only one (match root) operation can be removed;
|
||||
/// - the code produced by rewriters is final, it is not pattern-matched;
|
||||
/// - the matchers are applied in their order of appearance in the list;
|
||||
/// - if the match is found, the operation is rewritten immediately and the
|
||||
/// next _original_ operation is considered.
|
||||
/// In other words, for each operation, apply the first matching rewriter in the
|
||||
/// list and advance to the (lexically) next operation. This is similar to
|
||||
/// greedy worklist-based pattern rewriter, except that this operates on ML
|
||||
/// functions using an ML builder and does not maintain the work list. Note
|
||||
/// that, as of the time of writing, worklist-based rewriter did not support
|
||||
/// removing multiple operations either.
|
||||
template <typename... Patterns>
|
||||
void applyMLPatternsGreedily(
|
||||
Function *f, MLFuncGlobalLoweringState *funcWiseState = nullptr) {
|
||||
detail::OwningMLLoweringPatternList patterns;
|
||||
detail::ListAdder<Patterns...>::addPatternsToList(&patterns, f->getContext());
|
||||
|
||||
FuncBuilder builder(f);
|
||||
MLFuncLoweringRewriter rewriter(&builder);
|
||||
|
||||
llvm::SmallVector<Operation *, 16> ops;
|
||||
f->walk([&ops](Operation *op) { ops.push_back(op); });
|
||||
|
||||
for (Operation *op : ops) {
|
||||
for (const auto &pattern : patterns) {
|
||||
builder.setInsertionPoint(op);
|
||||
if (auto matchResult = pattern->match(op)) {
|
||||
pattern->rewriteOpInst(op, funcWiseState, std::move(*matchResult),
|
||||
&rewriter);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
} // end namespace mlir
|
||||
|
||||
#endif // MLIR_TRANSFORMS_MLPATTERNLOWERINGPASS_H
|
|
@ -39,7 +39,6 @@
|
|||
#include "mlir/Pass/Pass.h"
|
||||
#include "mlir/StandardOps/Ops.h"
|
||||
#include "mlir/Support/Functional.h"
|
||||
#include "mlir/Transforms/MLPatternLoweringPass.h"
|
||||
#include "mlir/Transforms/Passes.h"
|
||||
#include "mlir/VectorOps/VectorOps.h"
|
||||
|
||||
|
@ -99,30 +98,27 @@ namespace {
|
|||
/// 4. local memory deallocation.
|
||||
/// Minor variations occur depending on whether a VectorTransferReadOp or
|
||||
/// a VectorTransferWriteOp is rewritten.
|
||||
template <typename VectorTransferOpTy> class VectorTransferRewriter {
|
||||
public:
|
||||
VectorTransferRewriter(VectorTransferOpTy transfer,
|
||||
MLFuncLoweringRewriter *rewriter,
|
||||
MLFuncGlobalLoweringState *state);
|
||||
template <typename VectorTransferOpTy>
|
||||
struct VectorTransferRewriter : public RewritePattern {
|
||||
explicit VectorTransferRewriter(MLIRContext *context)
|
||||
: RewritePattern(VectorTransferOpTy::getOperationName(), 1, context) {}
|
||||
|
||||
/// Used for staging the transfer in a local scalar buffer.
|
||||
MemRefType tmpMemRefType() {
|
||||
MemRefType tmpMemRefType(VectorTransferOpTy transfer) const {
|
||||
auto vectorType = transfer.getVectorType();
|
||||
return MemRefType::get(vectorType.getShape(), vectorType.getElementType(),
|
||||
{}, 0);
|
||||
}
|
||||
|
||||
/// View of tmpMemRefType as one vector, used in vector load/store to tmp
|
||||
/// buffer.
|
||||
MemRefType vectorMemRefType() {
|
||||
MemRefType vectorMemRefType(VectorTransferOpTy transfer) const {
|
||||
return MemRefType::get({1}, transfer.getVectorType(), {}, 0);
|
||||
}
|
||||
/// Performs the rewrite.
|
||||
void rewrite();
|
||||
|
||||
private:
|
||||
VectorTransferOpTy transfer;
|
||||
MLFuncLoweringRewriter *rewriter;
|
||||
MLFuncGlobalLoweringState *state;
|
||||
/// Performs the rewrite.
|
||||
PatternMatchResult matchAndRewrite(Operation *op,
|
||||
PatternRewriter &rewriter) const override;
|
||||
};
|
||||
} // end anonymous namespace
|
||||
|
||||
|
@ -213,12 +209,6 @@ clip(VectorTransferOpTy transfer, edsc::MemRefView &view,
|
|||
return clippedScalarAccessExprs;
|
||||
}
|
||||
|
||||
template <typename VectorTransferOpTy>
|
||||
VectorTransferRewriter<VectorTransferOpTy>::VectorTransferRewriter(
|
||||
VectorTransferOpTy transfer, MLFuncLoweringRewriter *rewriter,
|
||||
MLFuncGlobalLoweringState *state)
|
||||
: transfer(transfer), rewriter(rewriter), state(state){};
|
||||
|
||||
/// Lowers VectorTransferReadOp into a combination of:
|
||||
/// 1. local memory allocation;
|
||||
/// 2. perfect loop nest over:
|
||||
|
@ -260,13 +250,20 @@ VectorTransferRewriter<VectorTransferOpTy>::VectorTransferRewriter(
|
|||
///
|
||||
/// TODO(ntv): implement alternatives to clipping.
|
||||
/// TODO(ntv): support non-data-parallel operations.
|
||||
template <> void VectorTransferRewriter<VectorTransferReadOp>::rewrite() {
|
||||
|
||||
/// Performs the rewrite.
|
||||
template <>
|
||||
PatternMatchResult
|
||||
VectorTransferRewriter<VectorTransferReadOp>::matchAndRewrite(
|
||||
Operation *op, PatternRewriter &rewriter) const {
|
||||
using namespace mlir::edsc;
|
||||
using namespace mlir::edsc::op;
|
||||
using namespace mlir::edsc::intrinsics;
|
||||
|
||||
VectorTransferReadOp transfer = op->cast<VectorTransferReadOp>();
|
||||
|
||||
// 1. Setup all the captures.
|
||||
ScopedContext scope(FuncBuilder(transfer.getOperation()), transfer.getLoc());
|
||||
ScopedContext scope(FuncBuilder(op), transfer.getLoc());
|
||||
IndexedValue remote(transfer.getMemRef());
|
||||
MemRefView view(transfer.getMemRef());
|
||||
VectorView vectorView(transfer.getVector());
|
||||
|
@ -281,9 +278,9 @@ template <> void VectorTransferRewriter<VectorTransferReadOp>::rewrite() {
|
|||
auto steps = vectorView.getSteps();
|
||||
|
||||
// 2. Emit alloc-copy-load-dealloc.
|
||||
ValueHandle tmp = alloc(tmpMemRefType());
|
||||
ValueHandle tmp = alloc(tmpMemRefType(transfer));
|
||||
IndexedValue local(tmp);
|
||||
ValueHandle vec = vector_type_cast(tmp, vectorMemRefType());
|
||||
ValueHandle vec = vector_type_cast(tmp, vectorMemRefType(transfer));
|
||||
LoopNestBuilder(pivs, lbs, ubs, steps)({
|
||||
// Computes clippedScalarAccessExprs in the loop nest scope (ivs exist).
|
||||
local(ivs) = remote(clip(transfer, view, ivs)),
|
||||
|
@ -292,8 +289,8 @@ template <> void VectorTransferRewriter<VectorTransferReadOp>::rewrite() {
|
|||
(dealloc(tmp)); // vexing parse
|
||||
|
||||
// 3. Propagate.
|
||||
transfer.replaceAllUsesWith(vectorValue.getValue());
|
||||
transfer.erase();
|
||||
rewriter.replaceOp(op, vectorValue.getValue());
|
||||
return matchSuccess();
|
||||
}
|
||||
|
||||
/// Lowers VectorTransferWriteOp into a combination of:
|
||||
|
@ -314,13 +311,18 @@ template <> void VectorTransferRewriter<VectorTransferReadOp>::rewrite() {
|
|||
///
|
||||
/// TODO(ntv): implement alternatives to clipping.
|
||||
/// TODO(ntv): support non-data-parallel operations.
|
||||
template <> void VectorTransferRewriter<VectorTransferWriteOp>::rewrite() {
|
||||
template <>
|
||||
PatternMatchResult
|
||||
VectorTransferRewriter<VectorTransferWriteOp>::matchAndRewrite(
|
||||
Operation *op, PatternRewriter &rewriter) const {
|
||||
using namespace mlir::edsc;
|
||||
using namespace mlir::edsc::op;
|
||||
using namespace mlir::edsc::intrinsics;
|
||||
|
||||
VectorTransferWriteOp transfer = op->cast<VectorTransferWriteOp>();
|
||||
|
||||
// 1. Setup all the captures.
|
||||
ScopedContext scope(FuncBuilder(transfer.getOperation()), transfer.getLoc());
|
||||
ScopedContext scope(FuncBuilder(op), transfer.getLoc());
|
||||
IndexedValue remote(transfer.getMemRef());
|
||||
MemRefView view(transfer.getMemRef());
|
||||
ValueHandle vectorValue(transfer.getVector());
|
||||
|
@ -336,9 +338,9 @@ template <> void VectorTransferRewriter<VectorTransferWriteOp>::rewrite() {
|
|||
auto steps = vectorView.getSteps();
|
||||
|
||||
// 2. Emit alloc-store-copy-dealloc.
|
||||
ValueHandle tmp = alloc(tmpMemRefType());
|
||||
ValueHandle tmp = alloc(tmpMemRefType(transfer));
|
||||
IndexedValue local(tmp);
|
||||
ValueHandle vec = vector_type_cast(tmp, vectorMemRefType());
|
||||
ValueHandle vec = vector_type_cast(tmp, vectorMemRefType(transfer));
|
||||
store(vectorValue, vec, {constant_index(0)});
|
||||
LoopNestBuilder(pivs, lbs, ubs, steps)({
|
||||
// Computes clippedScalarAccessExprs in the loop nest scope (ivs exist).
|
||||
|
@ -346,36 +348,23 @@ template <> void VectorTransferRewriter<VectorTransferWriteOp>::rewrite() {
|
|||
});
|
||||
(dealloc(tmp)); // vexing parse...
|
||||
|
||||
transfer.erase();
|
||||
rewriter.replaceOp(op, llvm::None);
|
||||
return matchSuccess();
|
||||
}
|
||||
|
||||
namespace {
|
||||
template <typename VectorTransferOpTy>
|
||||
class VectorTransferExpander : public MLLoweringPattern {
|
||||
public:
|
||||
explicit VectorTransferExpander(MLIRContext *context)
|
||||
: MLLoweringPattern(VectorTransferOpTy::getOperationName(), 1, context) {}
|
||||
|
||||
PatternMatchResult match(Operation *op) const override {
|
||||
if (m_Op<VectorTransferOpTy>().match(op))
|
||||
return matchSuccess();
|
||||
return matchFailure();
|
||||
}
|
||||
void rewriteOpInst(Operation *op, MLFuncGlobalLoweringState *funcWiseState,
|
||||
std::unique_ptr<PatternState> opState,
|
||||
MLFuncLoweringRewriter *rewriter) const override {
|
||||
VectorTransferRewriter<VectorTransferOpTy>(
|
||||
op->dyn_cast<VectorTransferOpTy>(), rewriter, funcWiseState)
|
||||
.rewrite();
|
||||
}
|
||||
};
|
||||
|
||||
struct LowerVectorTransfersPass
|
||||
: public FunctionPass<LowerVectorTransfersPass> {
|
||||
void runOnFunction() {
|
||||
auto &f = getFunction();
|
||||
applyMLPatternsGreedily<VectorTransferExpander<VectorTransferReadOp>,
|
||||
VectorTransferExpander<VectorTransferWriteOp>>(&f);
|
||||
OwningRewritePatternList patterns;
|
||||
auto *context = &getContext();
|
||||
patterns.push_back(
|
||||
llvm::make_unique<VectorTransferRewriter<VectorTransferReadOp>>(
|
||||
context));
|
||||
patterns.push_back(
|
||||
llvm::make_unique<VectorTransferRewriter<VectorTransferWriteOp>>(
|
||||
context));
|
||||
applyPatternsGreedily(getFunction(), std::move(patterns));
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -388,5 +377,3 @@ FunctionPassBase *mlir::createLowerVectorTransfersPass() {
|
|||
static PassRegistration<LowerVectorTransfersPass>
|
||||
pass("lower-vector-transfers", "Materializes vector transfer ops to a "
|
||||
"proper abstraction for the hardware");
|
||||
|
||||
#undef DEBUG_TYPE
|
||||
|
|
|
@ -131,8 +131,8 @@ func @materialize_read(%M: index, %N: index, %O: index, %P: index) {
|
|||
|
||||
// CHECK-LABEL:func @materialize_write(%arg0: index, %arg1: index, %arg2: index, %arg3: index) {
|
||||
func @materialize_write(%M: index, %N: index, %O: index, %P: index) {
|
||||
// CHECK-NEXT: %0 = alloc(%arg0, %arg1, %arg2, %arg3) : memref<?x?x?x?xf32>
|
||||
// CHECK-NEXT: %cst = constant splat<vector<5x4x3xf32>, 1.000000e+00> : vector<5x4x3xf32>
|
||||
// CHECK-NEXT: %0 = alloc(%arg0, %arg1, %arg2, %arg3) : memref<?x?x?x?xf32>
|
||||
// CHECK-NEXT: affine.for %[[I0:.*]] = 0 to %arg0 step 3 {
|
||||
// CHECK-NEXT: affine.for %[[I1:.*]] = 0 to %arg1 step 4 {
|
||||
// CHECK-NEXT: affine.for %[[I2:.*]] = 0 to %arg2 {
|
||||
|
|
Loading…
Reference in New Issue