[mlir] Add SubViewOp::getOrCreateRanges and fix folding pattern

The existing implementation of SubViewOp::getRanges relies on all
offsets/sizes/strides being dynamic SSA values and therefore does not
work in combination with canonicalization. This revision adds
SubViewOp::getOrCreateRanges, which creates the missing constants in the
canonicalized case.
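
To illustrate (a minimal sketch reusing the shapes from the new
canonicalize.mlir test below; it is not tied to any particular pass
pipeline): after canonicalization a subview may carry all of its
offsets/sizes/strides as static attributes rather than SSA operands,

    %1 = subview %0[0, 0][3, 4][1, 1]
        : memref<?x?xf32> to memref<3x4xf32, offset:?, strides:[?, 1]>

so there is no per-dimension Value for getRanges to return.
getOrCreateRanges instead materializes a ConstantIndexOp (with the
supplied builder and location) for each static entry and returns the
existing Value for each dynamic one.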

This allows reactivating the fused pass with staged pattern
applications.

However, another issue surfaces: the SubViewOp verifier is now too strict
to allow the folding. The existing folding pattern is therefore turned
into a canonicalization pattern that rewrites memref_cast + subview into
subview + memref_cast.
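
Concretely, the rewrite looks roughly as follows (this is the case
exercised by the canonicalize.mlir test added below; the printed types are
approximate):

    %0 = memref_cast %V : memref<16x16xf32> to memref<?x?xf32>
    %1 = subview %0[0, 0][3, 4][1, 1]
        : memref<?x?xf32> to memref<3x4xf32, offset:?, strides:[?, 1]>

becomes

    %0 = subview %V[0, 0][3, 4][1, 1]
        : memref<16x16xf32> to memref<3x4xf32, offset:0, strides:[16, 1]>
    %1 = memref_cast %0
        : memref<3x4xf32, offset:0, strides:[16, 1]> to memref<3x4xf32, offset:?, strides:[?, 1]>

i.e. the subview is recomputed on the cast's source with a fully inferred
result type, and a memref_cast back to the original result type keeps all
users type-correct.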

The transform-patterns-matmul-to-vector test can then be reactivated.

Differential Revision: https://reviews.llvm.org/D79759
Nicolas Vasilache 2020-05-12 22:21:36 -04:00
parent 195de442da
commit e0b99a5de4
6 changed files with 119 additions and 45 deletions

View File

@@ -2676,8 +2676,18 @@ def SubViewOp : Std_Op<"subview", [
     struct Range {
       Value offset, size, stride;
     };
-    // TODO: retire `getRanges`.
-    SmallVector<Range, 8> getRanges();
+    /// Return the list of SubViewOp::Range (i.e. offset, size, stride). Each
+    /// Range entry contains either the dynamic value or a ConstantIndexOp
+    /// constructed with `b` at location `loc`.
+    SmallVector<Range, 8> getOrCreateRanges(OpBuilder &b, Location loc);
+
+    /// A subview result type can be fully inferred from the source type and the
+    /// static representation of offsets, sizes and strides. Special sentinels
+    /// encode the dynamic case.
+    static Type inferSubViewResultType(MemRefType sourceMemRefType,
+                                       ArrayRef<int64_t> staticOffsets,
+                                       ArrayRef<int64_t> staticSizes,
+                                       ArrayRef<int64_t> staticStrides);

     /// Return the rank of the result MemRefType.
     unsigned getRank() { return getType().getRank(); }
@@ -2750,7 +2760,6 @@ def SubViewOp : Std_Op<"subview", [
   }];

   let hasCanonicalizer = 1;
-  let hasFolder = 1;
 }

 //===----------------------------------------------------------------------===//

View File

@@ -184,15 +184,16 @@ static LinalgOp fuse(Value producedView, LinalgOp producer, LinalgOp consumer,
   unsigned nWin = producer.getNumWindowLoops();
   SmallVector<SubViewOp::Range, 8> loopRanges(nPar + nRed + nWin);

+  OpBuilder b(consumer.getOperation());
+  auto loc = consumer.getLoc();
   // Iterate over dimensions identified by the producer map for `producerIdx`.
   // This defines a subset of the loop ranges that we need to complete later.
   for (auto en : llvm::enumerate(producerMap.getResults())) {
     unsigned posInProducerLoop = en.value().cast<AffineDimExpr>().getPosition();
-    loopRanges[posInProducerLoop] = subView.getRanges()[en.index()];
+    loopRanges[posInProducerLoop] =
+        subView.getOrCreateRanges(b, loc)[en.index()];
   }

-  OpBuilder b(consumer.getOperation());
-  auto loc = consumer.getLoc();
   // Iterate over all dimensions. For the dimensions not identified by the
   // producer map for `producerIdx`, we need to explicitly compute the view that
   // defines the loop ranges using the `producer`.

View File

@@ -153,7 +153,7 @@ static PromotionInfo promoteSubviewAsNewBuffer(OpBuilder &b, Location loc,
   SmallVector<Value, 8> fullSizes, partialSizes;
   fullSizes.reserve(rank);
   partialSizes.reserve(rank);
-  for (auto en : llvm::enumerate(subView.getRanges())) {
+  for (auto en : llvm::enumerate(subView.getOrCreateRanges(b, loc))) {
     auto rank = en.index();
     auto rangeValue = en.value();
     // Try to extract a tight constant.
@@ -169,7 +169,7 @@ static PromotionInfo promoteSubviewAsNewBuffer(OpBuilder &b, Location loc,
                                dynamicBuffers, folder, alignment);
   auto fullLocalView = folded_std_view(
       folder, MemRefType::get(dynSizes, viewType.getElementType()), buffer,
-      folded_std_constant_index(folder, 0), fullSizes);
+      zero, fullSizes);
   SmallVector<Value, 4> zeros(fullSizes.size(), zero);
   SmallVector<Value, 4> ones(fullSizes.size(), one);
   auto partialLocalView =

View File

@@ -2275,7 +2275,7 @@ Wrapper operator*(Wrapper a, int64_t b) {
 /// A subview result type can be fully inferred from the source type and the
 /// static representation of offsets, sizes and strides. Special sentinels
 /// encode the dynamic case.
-static Type inferSubViewResultType(MemRefType sourceMemRefType,
+Type SubViewOp::inferSubViewResultType(MemRefType sourceMemRefType,
                                    ArrayRef<int64_t> staticOffsets,
                                    ArrayRef<int64_t> staticSizes,
                                    ArrayRef<int64_t> staticStrides) {
@@ -2474,7 +2474,7 @@ static LogicalResult verify(SubViewOp op) {
     return failure();

   // Verify result type against inferred type.
-  auto expectedType = inferSubViewResultType(
+  auto expectedType = SubViewOp::inferSubViewResultType(
       op.getBaseMemRefType(), extractFromI64ArrayAttr(op.static_offsets()),
       extractFromI64ArrayAttr(op.static_sizes()),
       extractFromI64ArrayAttr(op.static_strides()));
@@ -2489,16 +2489,6 @@ raw_ostream &mlir::operator<<(raw_ostream &os, SubViewOp::Range &range) {
                    << range.stride;
 }

-SmallVector<SubViewOp::Range, 8> SubViewOp::getRanges() {
-  SmallVector<Range, 8> res;
-  unsigned rank = getType().getRank();
-  res.reserve(rank);
-  for (unsigned i = 0; i < rank; ++i)
-    res.emplace_back(Range{*(offsets().begin() + i), *(sizes().begin() + i),
-                           *(strides().begin() + i)});
-  return res;
-}
-
 static unsigned getNumDynamicEntriesUpToIdx(
     ArrayAttr attr, llvm::function_ref<bool(int64_t)> isDynamic, unsigned idx) {
   return std::count_if(attr.getValue().begin(), attr.getValue().begin() + idx,
@@ -2540,6 +2530,29 @@ unsigned SubViewOp::getIndexOfDynamicStride(unsigned idx) {
   return 1 + offsets().size() + sizes().size() + numDynamic;
 }

+/// Return the list of SubViewOp::Range (i.e. offset, size, stride). Each Range
+/// entry contains either the dynamic value or a ConstantIndexOp constructed
+/// with `b` at location `loc`.
+SmallVector<SubViewOp::Range, 8> SubViewOp::getOrCreateRanges(OpBuilder &b,
+                                                              Location loc) {
+  SmallVector<Range, 8> res;
+  unsigned rank = getType().getRank();
+  res.reserve(rank);
+  for (unsigned idx = 0; idx < rank; ++idx) {
+    auto offset = isDynamicOffset(idx)
+                      ? getDynamicOffset(idx)
+                      : b.create<ConstantIndexOp>(loc, getStaticOffset(idx));
+    auto size = isDynamicSize(idx)
+                    ? getDynamicSize(idx)
+                    : b.create<ConstantIndexOp>(loc, getStaticSize(idx));
+    auto stride = isDynamicStride(idx)
+                      ? getDynamicStride(idx)
+                      : b.create<ConstantIndexOp>(loc, getStaticStride(idx));
+    res.emplace_back(Range{offset, size, stride});
+  }
+  return res;
+}
+
 LogicalResult
 SubViewOp::getStaticStrides(SmallVectorImpl<int64_t> &staticStrides) {
   if (!strides().empty())
@@ -2583,7 +2596,8 @@ void canonicalizeSubViewPart(SmallVectorImpl<Value> &values,
 }

 /// Pattern to rewrite a subview op with constant arguments.
-class SubViewOpFolder final : public OpRewritePattern<SubViewOp> {
+class SubViewOpConstantArgumentFolder final
+    : public OpRewritePattern<SubViewOp> {
 public:
   using OpRewritePattern<SubViewOp>::OpRewritePattern;
@@ -2718,27 +2732,63 @@ bool mlir::canFoldIntoConsumerOp(MemRefCastOp castOp) {
   return true;
 }

-OpFoldResult SubViewOp::fold(ArrayRef<Attribute>) {
-  auto folds = [](Operation *op) {
-    bool folded = false;
-    for (OpOperand &operand : op->getOpOperands()) {
-      auto castOp = operand.get().getDefiningOp<MemRefCastOp>();
-      if (castOp && canFoldIntoConsumerOp(castOp)) {
-        operand.set(castOp.getOperand());
-        folded = true;
-      }
-    }
-    return folded ? success() : failure();
-  };
-
-  if (succeeded(folds(*this)))
-    return getResult();
-  return {};
+/// Pattern to rewrite a subview op with MemRefCast arguments.
+/// This essentially pushes memref_cast past its consuming subview when
+/// `canFoldIntoConsumerOp` is true.
+///
+/// Example:
+/// ```
+///   %0 = memref_cast %V : memref<16x16xf32> to memref<?x?xf32>
+///   %1 = subview %0[0, 0][3, 4][1, 1] :
+///     memref<?x?xf32> to memref<3x4xf32, offset:?, strides:[?, 1]>
+/// ```
+/// is rewritten into:
+/// ```
+///   %0 = subview %V: memref<16x16xf32> to memref<3x4xf32, #[[map0]]>
+///   %1 = memref_cast %0: memref<3x4xf32, offset:0, strides:[16, 1]> to
+///     memref<3x4xf32, offset:?, strides:[?, 1]>
+/// ```
+class SubViewOpMemRefCastFolder final : public OpRewritePattern<SubViewOp> {
+public:
+  using OpRewritePattern<SubViewOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(SubViewOp subViewOp,
+                                PatternRewriter &rewriter) const override {
+    // Any constant operand, just return to let SubViewOpConstantFolder kick in.
+    if (llvm::any_of(subViewOp.getOperands(), [](Value operand) {
+          return matchPattern(operand, m_ConstantIndex());
+        }))
+      return failure();
+
+    auto castOp = subViewOp.source().getDefiningOp<MemRefCastOp>();
+    if (!castOp)
+      return failure();
+
+    if (!canFoldIntoConsumerOp(castOp))
+      return failure();
+
+    /// Deduce the resultType of the SubViewOp using `inferSubViewResultType` on
+    /// the cast source operand type and the SubViewOp static information. This
+    /// is the resulting type if the MemRefCastOp were folded.
+    Type resultType = SubViewOp::inferSubViewResultType(
+        castOp.source().getType().cast<MemRefType>(),
+        extractFromI64ArrayAttr(subViewOp.static_offsets()),
+        extractFromI64ArrayAttr(subViewOp.static_sizes()),
+        extractFromI64ArrayAttr(subViewOp.static_strides()));
+    Value newSubView = rewriter.create<SubViewOp>(
+        subViewOp.getLoc(), resultType, castOp.source(), subViewOp.offsets(),
+        subViewOp.sizes(), subViewOp.strides(), subViewOp.static_offsets(),
+        subViewOp.static_sizes(), subViewOp.static_strides());
+    rewriter.replaceOpWithNewOp<MemRefCastOp>(subViewOp, subViewOp.getType(),
+                                              newSubView);
+    return success();
   }
+};

 void SubViewOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                             MLIRContext *context) {
-  results.insert<SubViewOpFolder>(context);
+  results.insert<SubViewOpConstantArgumentFolder, SubViewOpMemRefCastFolder>(
+      context);
 }

 //===----------------------------------------------------------------------===//

View File

@@ -1,7 +1,5 @@
-// TODO: this needs a fix to land before being reactivated.
-// RUN: ls
-// R_UN: mlir-opt %s -test-linalg-transform-patterns=test-matmul-to-vector-patterns-tile-1d | FileCheck %s
-// R_UN: mlir-opt %s -test-linalg-transform-patterns=test-matmul-to-vector-patterns-tile-2d | FileCheck %s
+// RUN: mlir-opt %s -test-linalg-transform-patterns=test-matmul-to-vector-patterns-tile-1d | FileCheck %s
+// RUN: mlir-opt %s -test-linalg-transform-patterns=test-matmul-to-vector-patterns-tile-2d | FileCheck %s

 func @matmul(%A: memref<1584x1584xf32, offset: 0, strides: [1584, 1]>,
              %B: memref<1584x1584xf32, offset: 0, strides: [1584, 1]>,

View File

@@ -941,3 +941,19 @@ func @memref_cast_folding_subview(%arg0: memref<4x5xf32>, %i: index) -> (memref<
   return %1: memref<?x?xf32, offset:? , strides: [?, ?]>
 }

+// -----
+
+// CHECK-DAG: #[[map0:.*]] = affine_map<(d0, d1) -> (d0 * 16 + d1)>
+// CHECK-DAG: #[[map1:.*]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>
+
+// CHECK-LABEL: func @memref_cast_folding_subview_static(
+func @memref_cast_folding_subview_static(%V: memref<16x16xf32>, %a: index, %b: index)
+  -> memref<3x4xf32, offset:?, strides:[?, 1]>
+{
+  %0 = memref_cast %V : memref<16x16xf32> to memref<?x?xf32>
+  %1 = subview %0[0, 0][3, 4][1, 1] : memref<?x?xf32> to memref<3x4xf32, offset:?, strides:[?, 1]>
+
+  // CHECK:  subview{{.*}}: memref<16x16xf32> to memref<3x4xf32, #[[map0]]>
+  // CHECK:  memref_cast{{.*}}: memref<3x4xf32, #[[map0]]> to memref<3x4xf32, #[[map1]]>
+  return %1: memref<3x4xf32, offset:?, strides:[?, 1]>
+}