[mlir] Add SubViewOp::getOrCreateRanges and fix folding pattern

The existing implementation of SubViewOp::getRanges relies on all
offsets/sizes/strides being dynamic values and does not work in
combination with canonicalization. This revision adds
SubViewOp::getOrCreateRanges, which creates the missing constants in
the canonicalized case.

This allows reactivating the fused pass with staged pattern
applications.

However, another issue surfaces: the SubViewOp verifier is now too
strict to allow folding. The existing folding pattern is therefore turned
into a canonicalization pattern that rewrites memref_cast + subview into
subview + memref_cast.

The transform-patterns-matmul-to-vector can then be reactivated.

Differential Revision: https://reviews.llvm.org/D79759
This commit is contained in:
Nicolas Vasilache 2020-05-12 22:21:36 -04:00
parent 195de442da
commit e0b99a5de4
6 changed files with 119 additions and 45 deletions

View File

@ -2676,8 +2676,18 @@ def SubViewOp : Std_Op<"subview", [
struct Range {
Value offset, size, stride;
};
// TODO: retire `getRanges`.
SmallVector<Range, 8> getRanges();
/// Return the list of SubViewOp::Range (i.e. offset, size, stride). Each
/// Range entry contains either the dynamic value or a ConstantIndexOp
/// constructed with `b` at location `loc`.
SmallVector<Range, 8> getOrCreateRanges(OpBuilder &b, Location loc);
/// A subview result type can be fully inferred from the source type and the
/// static representation of offsets, sizes and strides. Special sentinels
/// encode the dynamic case.
static Type inferSubViewResultType(MemRefType sourceMemRefType,
ArrayRef<int64_t> staticOffsets,
ArrayRef<int64_t> staticSizes,
ArrayRef<int64_t> staticStrides);
/// Return the rank of the result MemRefType.
unsigned getRank() { return getType().getRank(); }
@ -2750,7 +2760,6 @@ def SubViewOp : Std_Op<"subview", [
}];
let hasCanonicalizer = 1;
let hasFolder = 1;
}
//===----------------------------------------------------------------------===//

View File

@ -184,15 +184,16 @@ static LinalgOp fuse(Value producedView, LinalgOp producer, LinalgOp consumer,
unsigned nWin = producer.getNumWindowLoops();
SmallVector<SubViewOp::Range, 8> loopRanges(nPar + nRed + nWin);
OpBuilder b(consumer.getOperation());
auto loc = consumer.getLoc();
// Iterate over dimensions identified by the producer map for `producerIdx`.
// This defines a subset of the loop ranges that we need to complete later.
for (auto en : llvm::enumerate(producerMap.getResults())) {
unsigned posInProducerLoop = en.value().cast<AffineDimExpr>().getPosition();
loopRanges[posInProducerLoop] = subView.getRanges()[en.index()];
loopRanges[posInProducerLoop] =
subView.getOrCreateRanges(b, loc)[en.index()];
}
OpBuilder b(consumer.getOperation());
auto loc = consumer.getLoc();
// Iterate over all dimensions. For the dimensions not identified by the
// producer map for `producerIdx`, we need to explicitly compute the view that
// defines the loop ranges using the `producer`.

View File

@ -153,7 +153,7 @@ static PromotionInfo promoteSubviewAsNewBuffer(OpBuilder &b, Location loc,
SmallVector<Value, 8> fullSizes, partialSizes;
fullSizes.reserve(rank);
partialSizes.reserve(rank);
for (auto en : llvm::enumerate(subView.getRanges())) {
for (auto en : llvm::enumerate(subView.getOrCreateRanges(b, loc))) {
auto rank = en.index();
auto rangeValue = en.value();
// Try to extract a tight constant.
@ -169,7 +169,7 @@ static PromotionInfo promoteSubviewAsNewBuffer(OpBuilder &b, Location loc,
dynamicBuffers, folder, alignment);
auto fullLocalView = folded_std_view(
folder, MemRefType::get(dynSizes, viewType.getElementType()), buffer,
folded_std_constant_index(folder, 0), fullSizes);
zero, fullSizes);
SmallVector<Value, 4> zeros(fullSizes.size(), zero);
SmallVector<Value, 4> ones(fullSizes.size(), one);
auto partialLocalView =

View File

@ -2275,10 +2275,10 @@ Wrapper operator*(Wrapper a, int64_t b) {
/// A subview result type can be fully inferred from the source type and the
/// static representation of offsets, sizes and strides. Special sentinels
/// encode the dynamic case.
static Type inferSubViewResultType(MemRefType sourceMemRefType,
ArrayRef<int64_t> staticOffsets,
ArrayRef<int64_t> staticSizes,
ArrayRef<int64_t> staticStrides) {
Type SubViewOp::inferSubViewResultType(MemRefType sourceMemRefType,
ArrayRef<int64_t> staticOffsets,
ArrayRef<int64_t> staticSizes,
ArrayRef<int64_t> staticStrides) {
unsigned rank = sourceMemRefType.getRank();
(void)rank;
assert(staticOffsets.size() == rank &&
@ -2474,7 +2474,7 @@ static LogicalResult verify(SubViewOp op) {
return failure();
// Verify result type against inferred type.
auto expectedType = inferSubViewResultType(
auto expectedType = SubViewOp::inferSubViewResultType(
op.getBaseMemRefType(), extractFromI64ArrayAttr(op.static_offsets()),
extractFromI64ArrayAttr(op.static_sizes()),
extractFromI64ArrayAttr(op.static_strides()));
@ -2489,16 +2489,6 @@ raw_ostream &mlir::operator<<(raw_ostream &os, SubViewOp::Range &range) {
<< range.stride;
}
/// Return one Range (offset, size, stride) per result dimension, read straight
/// from the dynamic operand lists. Each list is indexed by dimension, so this
/// is only meaningful when every offset, size and stride is a dynamic value.
SmallVector<SubViewOp::Range, 8> SubViewOp::getRanges() {
  const unsigned numDims = getType().getRank();
  SmallVector<Range, 8> ranges;
  ranges.reserve(numDims);
  for (unsigned dim = 0; dim != numDims; ++dim) {
    Value offset = *(offsets().begin() + dim);
    Value size = *(sizes().begin() + dim);
    Value stride = *(strides().begin() + dim);
    ranges.push_back(Range{offset, size, stride});
  }
  return ranges;
}
static unsigned getNumDynamicEntriesUpToIdx(
ArrayAttr attr, llvm::function_ref<bool(int64_t)> isDynamic, unsigned idx) {
return std::count_if(attr.getValue().begin(), attr.getValue().begin() + idx,
@ -2540,6 +2530,29 @@ unsigned SubViewOp::getIndexOfDynamicStride(unsigned idx) {
return 1 + offsets().size() + sizes().size() + numDynamic;
}
/// Build one SubViewOp::Range (offset, size, stride) per result dimension.
/// A dynamic entry is returned as the corresponding operand; a static entry is
/// materialized as a ConstantIndexOp created with `b` at location `loc`.
/// Constants are created in offset, size, stride order for each dimension.
SmallVector<SubViewOp::Range, 8> SubViewOp::getOrCreateRanges(OpBuilder &b,
                                                              Location loc) {
  const unsigned numDims = getType().getRank();
  SmallVector<Range, 8> ranges;
  ranges.reserve(numDims);
  for (unsigned dim = 0; dim != numDims; ++dim) {
    Value offset;
    if (isDynamicOffset(dim))
      offset = getDynamicOffset(dim);
    else
      offset = b.create<ConstantIndexOp>(loc, getStaticOffset(dim));

    Value size;
    if (isDynamicSize(dim))
      size = getDynamicSize(dim);
    else
      size = b.create<ConstantIndexOp>(loc, getStaticSize(dim));

    Value stride;
    if (isDynamicStride(dim))
      stride = getDynamicStride(dim);
    else
      stride = b.create<ConstantIndexOp>(loc, getStaticStride(dim));

    ranges.push_back(Range{offset, size, stride});
  }
  return ranges;
}
LogicalResult
SubViewOp::getStaticStrides(SmallVectorImpl<int64_t> &staticStrides) {
if (!strides().empty())
@ -2583,7 +2596,8 @@ void canonicalizeSubViewPart(SmallVectorImpl<Value> &values,
}
/// Pattern to rewrite a subview op with constant arguments.
class SubViewOpFolder final : public OpRewritePattern<SubViewOp> {
class SubViewOpConstantArgumentFolder final
: public OpRewritePattern<SubViewOp> {
public:
using OpRewritePattern<SubViewOp>::OpRewritePattern;
@ -2718,27 +2732,63 @@ bool mlir::canFoldIntoConsumerOp(MemRefCastOp castOp) {
return true;
}
OpFoldResult SubViewOp::fold(ArrayRef<Attribute>) {
auto folds = [](Operation *op) {
bool folded = false;
for (OpOperand &operand : op->getOpOperands()) {
auto castOp = operand.get().getDefiningOp<MemRefCastOp>();
if (castOp && canFoldIntoConsumerOp(castOp)) {
operand.set(castOp.getOperand());
folded = true;
}
}
return folded ? success() : failure();
};
/// Pattern to rewrite a subview op with MemRefCast arguments.
/// This essentially pushes memref_cast past its consuming subview when
/// `canFoldIntoConsumerOp` is true.
///
/// Example:
/// ```
/// %0 = memref_cast %V : memref<16x16xf32> to memref<?x?xf32>
/// %1 = subview %0[0, 0][3, 4][1, 1] :
/// memref<?x?xf32> to memref<3x4xf32, offset:?, strides:[?, 1]>
/// ```
/// is rewritten into:
/// ```
/// %0 = subview %V: memref<16x16xf32> to memref<3x4xf32, #[[map0]]>
/// %1 = memref_cast %0: memref<3x4xf32, offset:0, strides:[16, 1]> to
/// memref<3x4xf32, offset:?, strides:[?, 1]>
/// ```
class SubViewOpMemRefCastFolder final : public OpRewritePattern<SubViewOp> {
public:
using OpRewritePattern<SubViewOp>::OpRewritePattern;
if (succeeded(folds(*this)))
return getResult();
return {};
}
/// Rewrite `subview(memref_cast(x))` into `memref_cast(subview(x))` when the
/// cast can be folded into its consumer. Returns failure() without touching
/// the IR when the pattern does not apply.
LogicalResult matchAndRewrite(SubViewOp subViewOp,
                              PatternRewriter &rewriter) const override {
  // If any operand is a constant index, bail out so that
  // SubViewOpConstantArgumentFolder can kick in instead.
  auto isConstantIndex = [](Value operand) {
    return matchPattern(operand, m_ConstantIndex());
  };
  if (llvm::any_of(subViewOp.getOperands(), isConstantIndex))
    return failure();

  // The source must be produced by a memref_cast that is foldable into its
  // consumer.
  auto castOp = subViewOp.source().getDefiningOp<MemRefCastOp>();
  if (!castOp || !canFoldIntoConsumerOp(castOp))
    return failure();

  // Infer the type the subview would have if it were applied directly to the
  // cast's source, reusing the static offsets/sizes/strides of the original
  // subview.
  auto castSourceType = castOp.source().getType().cast<MemRefType>();
  Type inferredType = SubViewOp::inferSubViewResultType(
      castSourceType, extractFromI64ArrayAttr(subViewOp.static_offsets()),
      extractFromI64ArrayAttr(subViewOp.static_sizes()),
      extractFromI64ArrayAttr(subViewOp.static_strides()));

  // Create the subview on the cast's source, then cast its result back to the
  // original subview type so existing users keep seeing the same type.
  Value rebuiltSubView = rewriter.create<SubViewOp>(
      subViewOp.getLoc(), inferredType, castOp.source(), subViewOp.offsets(),
      subViewOp.sizes(), subViewOp.strides(), subViewOp.static_offsets(),
      subViewOp.static_sizes(), subViewOp.static_strides());
  rewriter.replaceOpWithNewOp<MemRefCastOp>(subViewOp, subViewOp.getType(),
                                            rebuiltSubView);
  return success();
}
};
void SubViewOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
MLIRContext *context) {
results.insert<SubViewOpFolder>(context);
results.insert<SubViewOpConstantArgumentFolder, SubViewOpMemRefCastFolder>(
context);
}
//===----------------------------------------------------------------------===//

View File

@ -1,7 +1,5 @@
// TODO: this needs a fix to land before being reactivated.
// RUN: ls
// R_UN: mlir-opt %s -test-linalg-transform-patterns=test-matmul-to-vector-patterns-tile-1d | FileCheck %s
// R_UN: mlir-opt %s -test-linalg-transform-patterns=test-matmul-to-vector-patterns-tile-2d | FileCheck %s
// RUN: mlir-opt %s -test-linalg-transform-patterns=test-matmul-to-vector-patterns-tile-1d | FileCheck %s
// RUN: mlir-opt %s -test-linalg-transform-patterns=test-matmul-to-vector-patterns-tile-2d | FileCheck %s
func @matmul(%A: memref<1584x1584xf32, offset: 0, strides: [1584, 1]>,
%B: memref<1584x1584xf32, offset: 0, strides: [1584, 1]>,

View File

@ -941,3 +941,19 @@ func @memref_cast_folding_subview(%arg0: memref<4x5xf32>, %i: index) -> (memref<
return %1: memref<?x?xf32, offset:? , strides: [?, ?]>
}
// -----
// CHECK-DAG: #[[map0:.*]] = affine_map<(d0, d1) -> (d0 * 16 + d1)>
// CHECK-DAG: #[[map1:.*]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>
// CHECK-LABEL: func @memref_cast_folding_subview_static(
func @memref_cast_folding_subview_static(%V: memref<16x16xf32>, %a: index, %b: index)
-> memref<3x4xf32, offset:?, strides:[?, 1]>
{
%0 = memref_cast %V : memref<16x16xf32> to memref<?x?xf32>
%1 = subview %0[0, 0][3, 4][1, 1] : memref<?x?xf32> to memref<3x4xf32, offset:?, strides:[?, 1]>
// CHECK: subview{{.*}}: memref<16x16xf32> to memref<3x4xf32, #[[map0]]>
// CHECK: memref_cast{{.*}}: memref<3x4xf32, #[[map0]]> to memref<3x4xf32, #[[map1]]>
return %1: memref<3x4xf32, offset:?, strides:[?, 1]>
}