[mlir] Add SubViewOp::getOrCreateRanges and fix folding pattern
The existing implementation of SubViewOp::getRanges relies on all offsets/sizes/strides being dynamic values and does not work in combination with canonicalization. This revision adds SubViewOp::getOrCreateRanges to create the missing constants in the canonicalized case.

This allows reactivating the fused pass with staged pattern applications.

However, another issue surfaces: the SubViewOp verifier is now too strict to allow folding. The existing folding pattern is turned into a canonicalization pattern which rewrites memref_cast + subview into subview + memref_cast.

The transform-patterns-matmul-to-vector test can then be reactivated.

Differential Revision: https://reviews.llvm.org/D79759
parent 195de442da
commit e0b99a5de4
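To make the first paragraph concrete, here is a rough sketch (not part of the diff; the operands, shapes, and value names are illustrative) of what canonicalization does to a subview: constant offsets/sizes/strides are folded from SSA operands into the op's static_offsets/static_sizes/static_strides attributes, so getRanges, which only reads the operand lists, comes up empty, while getOrCreateRanges can materialize ConstantIndexOps for the static entries.

```
// Before canonicalization: every offset/size/stride is an SSA operand, which
// is what getRanges expects.
%c0 = constant 0 : index
%c1 = constant 1 : index
%c3 = constant 3 : index
%c4 = constant 4 : index
%sv = subview %src[%c0, %c0][%c3, %c4][%c1, %c1]
  : memref<16x16xf32> to memref<?x?xf32, offset: ?, strides: [?, ?]>

// After canonicalization: the constants live in the op's static attributes and
// the operand lists are empty, so getOrCreateRanges rebuilds ConstantIndexOps
// on demand. (The canonicalizer also inserts a memref_cast back to the
// original result type for existing users; omitted here.)
%sv = subview %src[0, 0][3, 4][1, 1]
  : memref<16x16xf32> to memref<3x4xf32, offset: 0, strides: [16, 1]>
```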
@@ -2676,8 +2676,18 @@ def SubViewOp : Std_Op<"subview", [
    struct Range {
      Value offset, size, stride;
    };
    // TODO: retire `getRanges`.
    SmallVector<Range, 8> getRanges();
    /// Return the list of SubViewOp::Range (i.e. offset, size, stride). Each
    /// Range entry contains either the dynamic value or a ConstantIndexOp
    /// constructed with `b` at location `loc`.
    SmallVector<Range, 8> getOrCreateRanges(OpBuilder &b, Location loc);

    /// A subview result type can be fully inferred from the source type and the
    /// static representation of offsets, sizes and strides. Special sentinels
    /// encode the dynamic case.
    static Type inferSubViewResultType(MemRefType sourceMemRefType,
                                       ArrayRef<int64_t> staticOffsets,
                                       ArrayRef<int64_t> staticSizes,
                                       ArrayRef<int64_t> staticStrides);

    /// Return the rank of the result MemRefType.
    unsigned getRank() { return getType().getRank(); }

@@ -2750,7 +2760,6 @@ def SubViewOp : Std_Op<"subview", [
  }];

  let hasCanonicalizer = 1;
  let hasFolder = 1;
}

//===----------------------------------------------------------------------===//

@@ -184,15 +184,16 @@ static LinalgOp fuse(Value producedView, LinalgOp producer, LinalgOp consumer,
  unsigned nWin = producer.getNumWindowLoops();
  SmallVector<SubViewOp::Range, 8> loopRanges(nPar + nRed + nWin);

  OpBuilder b(consumer.getOperation());
  auto loc = consumer.getLoc();
  // Iterate over dimensions identified by the producer map for `producerIdx`.
  // This defines a subset of the loop ranges that we need to complete later.
  for (auto en : llvm::enumerate(producerMap.getResults())) {
    unsigned posInProducerLoop = en.value().cast<AffineDimExpr>().getPosition();
    loopRanges[posInProducerLoop] = subView.getRanges()[en.index()];
    loopRanges[posInProducerLoop] =
        subView.getOrCreateRanges(b, loc)[en.index()];
  }

  OpBuilder b(consumer.getOperation());
  auto loc = consumer.getLoc();
  // Iterate over all dimensions. For the dimensions not identified by the
  // producer map for `producerIdx`, we need to explicitly compute the view that
  // defines the loop ranges using the `producer`.

@@ -153,7 +153,7 @@ static PromotionInfo promoteSubviewAsNewBuffer(OpBuilder &b, Location loc,
  SmallVector<Value, 8> fullSizes, partialSizes;
  fullSizes.reserve(rank);
  partialSizes.reserve(rank);
  for (auto en : llvm::enumerate(subView.getRanges())) {
  for (auto en : llvm::enumerate(subView.getOrCreateRanges(b, loc))) {
    auto rank = en.index();
    auto rangeValue = en.value();
    // Try to extract a tight constant.

@@ -169,7 +169,7 @@ static PromotionInfo promoteSubviewAsNewBuffer(OpBuilder &b, Location loc,
                     dynamicBuffers, folder, alignment);
  auto fullLocalView = folded_std_view(
      folder, MemRefType::get(dynSizes, viewType.getElementType()), buffer,
      folded_std_constant_index(folder, 0), fullSizes);
      zero, fullSizes);
  SmallVector<Value, 4> zeros(fullSizes.size(), zero);
  SmallVector<Value, 4> ones(fullSizes.size(), one);
  auto partialLocalView =

@@ -2275,7 +2275,7 @@ Wrapper operator*(Wrapper a, int64_t b) {
/// A subview result type can be fully inferred from the source type and the
/// static representation of offsets, sizes and strides. Special sentinels
/// encode the dynamic case.
static Type inferSubViewResultType(MemRefType sourceMemRefType,
Type SubViewOp::inferSubViewResultType(MemRefType sourceMemRefType,
                                       ArrayRef<int64_t> staticOffsets,
                                       ArrayRef<int64_t> staticSizes,
                                       ArrayRef<int64_t> staticStrides) {

@@ -2474,7 +2474,7 @@ static LogicalResult verify(SubViewOp op) {
    return failure();

  // Verify result type against inferred type.
  auto expectedType = inferSubViewResultType(
  auto expectedType = SubViewOp::inferSubViewResultType(
      op.getBaseMemRefType(), extractFromI64ArrayAttr(op.static_offsets()),
      extractFromI64ArrayAttr(op.static_sizes()),
      extractFromI64ArrayAttr(op.static_strides()));

@@ -2489,16 +2489,6 @@ raw_ostream &mlir::operator<<(raw_ostream &os, SubViewOp::Range &range) {
            << range.stride;
}

SmallVector<SubViewOp::Range, 8> SubViewOp::getRanges() {
  SmallVector<Range, 8> res;
  unsigned rank = getType().getRank();
  res.reserve(rank);
  for (unsigned i = 0; i < rank; ++i)
    res.emplace_back(Range{*(offsets().begin() + i), *(sizes().begin() + i),
                           *(strides().begin() + i)});
  return res;
}

static unsigned getNumDynamicEntriesUpToIdx(
    ArrayAttr attr, llvm::function_ref<bool(int64_t)> isDynamic, unsigned idx) {
  return std::count_if(attr.getValue().begin(), attr.getValue().begin() + idx,

@@ -2540,6 +2530,29 @@ unsigned SubViewOp::getIndexOfDynamicStride(unsigned idx) {
  return 1 + offsets().size() + sizes().size() + numDynamic;
}

/// Return the list of SubViewOp::Range (i.e. offset, size, stride). Each Range
/// entry contains either the dynamic value or a ConstantIndexOp constructed
/// with `b` at location `loc`.
SmallVector<SubViewOp::Range, 8> SubViewOp::getOrCreateRanges(OpBuilder &b,
                                                              Location loc) {
  SmallVector<Range, 8> res;
  unsigned rank = getType().getRank();
  res.reserve(rank);
  for (unsigned idx = 0; idx < rank; ++idx) {
    auto offset = isDynamicOffset(idx)
                      ? getDynamicOffset(idx)
                      : b.create<ConstantIndexOp>(loc, getStaticOffset(idx));
    auto size = isDynamicSize(idx)
                    ? getDynamicSize(idx)
                    : b.create<ConstantIndexOp>(loc, getStaticSize(idx));
    auto stride = isDynamicStride(idx)
                      ? getDynamicStride(idx)
                      : b.create<ConstantIndexOp>(loc, getStaticStride(idx));
    res.emplace_back(Range{offset, size, stride});
  }
  return res;
}

LogicalResult
SubViewOp::getStaticStrides(SmallVectorImpl<int64_t> &staticStrides) {
  if (!strides().empty())

@@ -2583,7 +2596,8 @@ void canonicalizeSubViewPart(SmallVectorImpl<Value> &values,
}

/// Pattern to rewrite a subview op with constant arguments.
class SubViewOpFolder final : public OpRewritePattern<SubViewOp> {
class SubViewOpConstantArgumentFolder final
    : public OpRewritePattern<SubViewOp> {
public:
  using OpRewritePattern<SubViewOp>::OpRewritePattern;

@@ -2718,27 +2732,63 @@ bool mlir::canFoldIntoConsumerOp(MemRefCastOp castOp) {
  return true;
}

OpFoldResult SubViewOp::fold(ArrayRef<Attribute>) {
  auto folds = [](Operation *op) {
    bool folded = false;
    for (OpOperand &operand : op->getOpOperands()) {
      auto castOp = operand.get().getDefiningOp<MemRefCastOp>();
      if (castOp && canFoldIntoConsumerOp(castOp)) {
        operand.set(castOp.getOperand());
        folded = true;
      }
    }
    return folded ? success() : failure();
  };
/// Pattern to rewrite a subview op with MemRefCast arguments.
/// This essentially pushes memref_cast past its consuming subview when
/// `canFoldIntoConsumerOp` is true.
///
/// Example:
/// ```
///   %0 = memref_cast %V : memref<16x16xf32> to memref<?x?xf32>
///   %1 = subview %0[0, 0][3, 4][1, 1] :
///     memref<?x?xf32> to memref<3x4xf32, offset:?, strides:[?, 1]>
/// ```
/// is rewritten into:
/// ```
///   %0 = subview %V: memref<16x16xf32> to memref<3x4xf32, #[[map0]]>
///   %1 = memref_cast %0: memref<3x4xf32, offset:0, strides:[16, 1]> to
///     memref<3x4xf32, offset:?, strides:[?, 1]>
/// ```
class SubViewOpMemRefCastFolder final : public OpRewritePattern<SubViewOp> {
public:
  using OpRewritePattern<SubViewOp>::OpRewritePattern;

  if (succeeded(folds(*this)))
    return getResult();
  return {};
  LogicalResult matchAndRewrite(SubViewOp subViewOp,
                                PatternRewriter &rewriter) const override {
    // Any constant operand, just return to let SubViewOpConstantFolder kick in.
    if (llvm::any_of(subViewOp.getOperands(), [](Value operand) {
          return matchPattern(operand, m_ConstantIndex());
        }))
      return failure();

    auto castOp = subViewOp.source().getDefiningOp<MemRefCastOp>();
    if (!castOp)
      return failure();

    if (!canFoldIntoConsumerOp(castOp))
      return failure();

    /// Deduce the resultType of the SubViewOp using `inferSubViewResultType` on
    /// the cast source operand type and the SubViewOp static information. This
    /// is the resulting type if the MemRefCastOp were folded.
    Type resultType = SubViewOp::inferSubViewResultType(
        castOp.source().getType().cast<MemRefType>(),
        extractFromI64ArrayAttr(subViewOp.static_offsets()),
        extractFromI64ArrayAttr(subViewOp.static_sizes()),
        extractFromI64ArrayAttr(subViewOp.static_strides()));
    Value newSubView = rewriter.create<SubViewOp>(
        subViewOp.getLoc(), resultType, castOp.source(), subViewOp.offsets(),
        subViewOp.sizes(), subViewOp.strides(), subViewOp.static_offsets(),
        subViewOp.static_sizes(), subViewOp.static_strides());
    rewriter.replaceOpWithNewOp<MemRefCastOp>(subViewOp, subViewOp.getType(),
                                              newSubView);
    return success();
  }
};

void SubViewOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                            MLIRContext *context) {
  results.insert<SubViewOpFolder>(context);
  results.insert<SubViewOpConstantArgumentFolder, SubViewOpMemRefCastFolder>(
      context);
}

//===----------------------------------------------------------------------===//

@@ -1,7 +1,5 @@
// TODO: this needs a fix to land before being reactivated.
// RUN: ls
// R_UN: mlir-opt %s -test-linalg-transform-patterns=test-matmul-to-vector-patterns-tile-1d | FileCheck %s
// R_UN: mlir-opt %s -test-linalg-transform-patterns=test-matmul-to-vector-patterns-tile-2d | FileCheck %s
// RUN: mlir-opt %s -test-linalg-transform-patterns=test-matmul-to-vector-patterns-tile-1d | FileCheck %s
// RUN: mlir-opt %s -test-linalg-transform-patterns=test-matmul-to-vector-patterns-tile-2d | FileCheck %s

func @matmul(%A: memref<1584x1584xf32, offset: 0, strides: [1584, 1]>,
             %B: memref<1584x1584xf32, offset: 0, strides: [1584, 1]>,

@@ -941,3 +941,19 @@ func @memref_cast_folding_subview(%arg0: memref<4x5xf32>, %i: index) -> (memref<
  return %1: memref<?x?xf32, offset:? , strides: [?, ?]>
}

// -----

// CHECK-DAG: #[[map0:.*]] = affine_map<(d0, d1) -> (d0 * 16 + d1)>
// CHECK-DAG: #[[map1:.*]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>

// CHECK-LABEL: func @memref_cast_folding_subview_static(
func @memref_cast_folding_subview_static(%V: memref<16x16xf32>, %a: index, %b: index)
  -> memref<3x4xf32, offset:?, strides:[?, 1]>
{
  %0 = memref_cast %V : memref<16x16xf32> to memref<?x?xf32>
  %1 = subview %0[0, 0][3, 4][1, 1] : memref<?x?xf32> to memref<3x4xf32, offset:?, strides:[?, 1]>

  // CHECK: subview{{.*}}: memref<16x16xf32> to memref<3x4xf32, #[[map0]]>
  // CHECK: memref_cast{{.*}}: memref<3x4xf32, #[[map0]]> to memref<3x4xf32, #[[map1]]>
  return %1: memref<3x4xf32, offset:?, strides:[?, 1]>
}
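For orientation, here is a sketch (not part of the diff) of the canonicalized IR these CHECK lines are matching, with #map0/#map1 bound as in the CHECK-DAG lines above and the now-dead memref_cast of %V removed:

```
%0 = subview %V[0, 0][3, 4][1, 1]
  : memref<16x16xf32> to memref<3x4xf32, #map0>
%1 = memref_cast %0 : memref<3x4xf32, #map0> to memref<3x4xf32, #map1>
return %1 : memref<3x4xf32, #map1>
```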