[mlir][vector] Add more vector Ops canonicalization

Add canonicalization for BroadcastOp, ExtractStrideSlicesOp and ShapeCastOp

Differential Revision: https://reviews.llvm.org/D93120
This commit is contained in:
Thomas Raoux 2020-12-11 07:08:55 -08:00
parent 1876a2914f
commit 74186880ba
3 changed files with 142 additions and 5 deletions

View File

@ -271,6 +271,7 @@ def Vector_BroadcastOp :
}];
let assemblyFormat = "$source attr-dict `:` type($source) `to` type($vector)";
let hasFolder = 1;
let hasCanonicalizer = 1;
}
def Vector_ShuffleOp :

View File

@ -1110,6 +1110,36 @@ OpFoldResult BroadcastOp::fold(ArrayRef<Attribute> operands) {
return {};
}
namespace {
// BroadcastOp can only add dimensions or broadcast a dimension from 1 to N. In
// the degenerated case where the broadcast only adds dimensions of size 1 it
// can be replaced by a ShapeCastOp. This canonicalization checks if the total
// number of elements is the same before and after the broadcast to detect if
// the only change in the vector type are new dimensions of size 1.
class BroadcastToShapeCast final : public OpRewritePattern<BroadcastOp> {
public:
using OpRewritePattern<BroadcastOp>::OpRewritePattern;
LogicalResult matchAndRewrite(BroadcastOp broadcastOp,
PatternRewriter &rewriter) const override {
auto srcVecType = broadcastOp.getSourceType().dyn_cast<VectorType>();
if (!srcVecType || broadcastOp.getVectorType().getNumElements() !=
srcVecType.getNumElements())
return failure();
rewriter.replaceOpWithNewOp<ShapeCastOp>(
broadcastOp, broadcastOp.getVectorType(), broadcastOp.source());
return success();
}
};
} // namespace
void BroadcastOp::getCanonicalizationPatterns(
OwningRewritePatternList &results, MLIRContext *context) {
results.insert<BroadcastToShapeCast>(context);
}
//===----------------------------------------------------------------------===//
// ShuffleOp
//===----------------------------------------------------------------------===//
@ -1768,7 +1798,8 @@ void ExtractStridedSliceOp::getOffsets(SmallVectorImpl<int64_t> &results) {
namespace {
// Pattern to rewrite a ExtractStridedSliceOp(ConstantMaskOp) -> ConstantMaskOp.
// Pattern to rewrite an ExtractStridedSliceOp(ConstantMaskOp) to
// ConstantMaskOp.
class StridedSliceConstantMaskFolder final
: public OpRewritePattern<ExtractStridedSliceOp> {
public:
@ -1847,14 +1878,70 @@ public:
}
};
// Helper that returns a subset of `arrayAttr` as a vector of int64_t.
static SmallVector<int64_t, 4> getI64SubArray(ArrayAttr arrayAttr,
unsigned dropFront = 0,
unsigned dropBack = 0) {
assert(arrayAttr.size() > dropFront + dropBack && "Out of bounds");
auto range = arrayAttr.getAsRange<IntegerAttr>();
SmallVector<int64_t, 4> res;
res.reserve(arrayAttr.size() - dropFront - dropBack);
for (auto it = range.begin() + dropFront, eit = range.end() - dropBack;
it != eit; ++it)
res.push_back((*it).getValue().getSExtValue());
return res;
}
// Pattern to rewrite an ExtractStridedSliceOp(BroadcastOp) to
// BroadcastOp(ExtractStrideSliceOp).
class StridedSliceBroadcast final
: public OpRewritePattern<ExtractStridedSliceOp> {
public:
using OpRewritePattern<ExtractStridedSliceOp>::OpRewritePattern;
LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
PatternRewriter &rewriter) const override {
auto broadcast = op.vector().getDefiningOp<BroadcastOp>();
if (!broadcast)
return failure();
auto srcVecType = broadcast.source().getType().dyn_cast<VectorType>();
unsigned srcRrank = srcVecType ? srcVecType.getRank() : 0;
auto dstVecType = op.getType().cast<VectorType>();
unsigned dstRank = dstVecType.getRank();
unsigned rankDiff = dstRank - srcRrank;
// Check if the most inner dimensions of the source of the broacast are the
// same as the destination of the extract. If this is the case we can just
// use a broadcast as the original dimensions are untouched.
bool lowerDimMatch = true;
for (unsigned i = 0; i < srcRrank; i++) {
if (srcVecType.getDimSize(i) != dstVecType.getDimSize(i + rankDiff)) {
lowerDimMatch = false;
break;
}
}
Value source = broadcast.source();
if (!lowerDimMatch) {
// The inner dimensions don't match, it means we need to extract from the
// source of the orignal broadcast and then broadcast the extracted value.
source = rewriter.create<ExtractStridedSliceOp>(
op->getLoc(), source,
getI64SubArray(op.offsets(), /* dropFront=*/rankDiff),
getI64SubArray(op.sizes(), /* dropFront=*/rankDiff),
getI64SubArray(op.strides(), /* dropFront=*/rankDiff));
}
rewriter.replaceOpWithNewOp<BroadcastOp>(op, op.getType(), source);
return success();
}
};
} // end anonymous namespace
void ExtractStridedSliceOp::getCanonicalizationPatterns(
OwningRewritePatternList &results, MLIRContext *context) {
// Pattern to rewrite a ExtractStridedSliceOp(ConstantMaskOp) ->
// ConstantMaskOp and ExtractStridedSliceOp(ConstantOp) -> ConstantOp.
results.insert<StridedSliceConstantMaskFolder, StridedSliceConstantFolder>(
context);
results.insert<StridedSliceConstantMaskFolder, StridedSliceConstantFolder,
StridedSliceBroadcast>(context);
}
//===----------------------------------------------------------------------===//
@ -2652,10 +2739,12 @@ OpFoldResult ShapeCastOp::fold(ArrayRef<Attribute> operands) {
return source();
// Canceling shape casts.
if (auto otherOp = source().getDefiningOp<ShapeCastOp>())
if (auto otherOp = source().getDefiningOp<ShapeCastOp>()) {
if (result().getType() == otherOp.source().getType())
return otherOp.source();
setOperand(otherOp.source());
return getResult();
}
return {};
}

View File

@ -613,4 +613,51 @@ func @extract_strided_constant() -> (vector<12x2xf32>, vector<2x13x3xi32>) {
return %0, %1 : vector<12x2xf32>, vector<2x13x3xi32>
}
// -----
// CHECK-LABEL: extract_strided_broadcast
// CHECK: %[[B:.*]] = vector.broadcast %{{.*}} : vector<4xf16> to vector<2x4xf16>
// CHECK-NEXT: return %[[B]] : vector<2x4xf16>
func @extract_strided_broadcast(%arg0: vector<4xf16>) -> vector<2x4xf16> {
%0 = vector.broadcast %arg0 : vector<4xf16> to vector<16x4xf16>
%1 = vector.extract_strided_slice %0
{offsets = [0, 0], sizes = [2, 4], strides = [1, 1]} :
vector<16x4xf16> to vector<2x4xf16>
return %1 : vector<2x4xf16>
}
// -----
// CHECK-LABEL: extract_strided_broadcast2
// CHECK: %[[E:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0], sizes = [2], strides = [1]} : vector<4xf16> to vector<2xf16>
// CHECK-NEXT: %[[B:.*]] = vector.broadcast %[[E]] : vector<2xf16> to vector<2x2xf16>
// CHECK-NEXT: return %[[B]] : vector<2x2xf16>
func @extract_strided_broadcast2(%arg0: vector<4xf16>) -> vector<2x2xf16> {
%0 = vector.broadcast %arg0 : vector<4xf16> to vector<16x4xf16>
%1 = vector.extract_strided_slice %0
{offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} :
vector<16x4xf16> to vector<2x2xf16>
return %1 : vector<2x2xf16>
}
// -----
// CHECK-LABEL: consecutive_shape_cast
// CHECK: %[[C:.*]] = vector.shape_cast %{{.*}} : vector<16xf16> to vector<4x4xf16>
// CHECK-NEXT: return %[[C]] : vector<4x4xf16>
func @consecutive_shape_cast(%arg0: vector<16xf16>) -> vector<4x4xf16> {
%0 = vector.shape_cast %arg0 : vector<16xf16> to vector<2x8xf16>
%1 = vector.shape_cast %0 : vector<2x8xf16> to vector<4x4xf16>
return %1 : vector<4x4xf16>
}
// -----
// CHECK-LABEL: broadcast_to_shapecast
// CHECK: %[[C:.*]] = vector.shape_cast %{{.*}} : vector<4x4xf16> to vector<1x4x4xf16>
// CHECK-NEXT: return %[[C]] : vector<1x4x4xf16>
func @broadcast_to_shapecast(%arg0: vector<4x4xf16>) -> vector<1x4x4xf16> {
%0 = vector.broadcast %arg0 : vector<4x4xf16> to vector<1x4x4xf16>
return %0 : vector<1x4x4xf16>
}