forked from OSchip/llvm-project
[mlir][vector] Add more vector Ops canonicalization
Add canonicalization for BroadcastOp, ExtractStrideSlicesOp and ShapeCastOp Differential Revision: https://reviews.llvm.org/D93120
This commit is contained in:
parent
1876a2914f
commit
74186880ba
|
@ -271,6 +271,7 @@ def Vector_BroadcastOp :
|
|||
}];
|
||||
let assemblyFormat = "$source attr-dict `:` type($source) `to` type($vector)";
|
||||
let hasFolder = 1;
|
||||
let hasCanonicalizer = 1;
|
||||
}
|
||||
|
||||
def Vector_ShuffleOp :
|
||||
|
|
|
@ -1110,6 +1110,36 @@ OpFoldResult BroadcastOp::fold(ArrayRef<Attribute> operands) {
|
|||
return {};
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
||||
// BroadcastOp can only add dimensions or broadcast a dimension from 1 to N. In
|
||||
// the degenerated case where the broadcast only adds dimensions of size 1 it
|
||||
// can be replaced by a ShapeCastOp. This canonicalization checks if the total
|
||||
// number of elements is the same before and after the broadcast to detect if
|
||||
// the only change in the vector type are new dimensions of size 1.
|
||||
class BroadcastToShapeCast final : public OpRewritePattern<BroadcastOp> {
|
||||
public:
|
||||
using OpRewritePattern<BroadcastOp>::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(BroadcastOp broadcastOp,
|
||||
PatternRewriter &rewriter) const override {
|
||||
auto srcVecType = broadcastOp.getSourceType().dyn_cast<VectorType>();
|
||||
if (!srcVecType || broadcastOp.getVectorType().getNumElements() !=
|
||||
srcVecType.getNumElements())
|
||||
return failure();
|
||||
rewriter.replaceOpWithNewOp<ShapeCastOp>(
|
||||
broadcastOp, broadcastOp.getVectorType(), broadcastOp.source());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
void BroadcastOp::getCanonicalizationPatterns(
|
||||
OwningRewritePatternList &results, MLIRContext *context) {
|
||||
results.insert<BroadcastToShapeCast>(context);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ShuffleOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -1768,7 +1798,8 @@ void ExtractStridedSliceOp::getOffsets(SmallVectorImpl<int64_t> &results) {
|
|||
|
||||
namespace {
|
||||
|
||||
// Pattern to rewrite a ExtractStridedSliceOp(ConstantMaskOp) -> ConstantMaskOp.
|
||||
// Pattern to rewrite an ExtractStridedSliceOp(ConstantMaskOp) to
|
||||
// ConstantMaskOp.
|
||||
class StridedSliceConstantMaskFolder final
|
||||
: public OpRewritePattern<ExtractStridedSliceOp> {
|
||||
public:
|
||||
|
@ -1847,14 +1878,70 @@ public:
|
|||
}
|
||||
};
|
||||
|
||||
// Helper that returns a subset of `arrayAttr` as a vector of int64_t.
|
||||
static SmallVector<int64_t, 4> getI64SubArray(ArrayAttr arrayAttr,
|
||||
unsigned dropFront = 0,
|
||||
unsigned dropBack = 0) {
|
||||
assert(arrayAttr.size() > dropFront + dropBack && "Out of bounds");
|
||||
auto range = arrayAttr.getAsRange<IntegerAttr>();
|
||||
SmallVector<int64_t, 4> res;
|
||||
res.reserve(arrayAttr.size() - dropFront - dropBack);
|
||||
for (auto it = range.begin() + dropFront, eit = range.end() - dropBack;
|
||||
it != eit; ++it)
|
||||
res.push_back((*it).getValue().getSExtValue());
|
||||
return res;
|
||||
}
|
||||
|
||||
// Pattern to rewrite an ExtractStridedSliceOp(BroadcastOp) to
|
||||
// BroadcastOp(ExtractStrideSliceOp).
|
||||
class StridedSliceBroadcast final
|
||||
: public OpRewritePattern<ExtractStridedSliceOp> {
|
||||
public:
|
||||
using OpRewritePattern<ExtractStridedSliceOp>::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
auto broadcast = op.vector().getDefiningOp<BroadcastOp>();
|
||||
if (!broadcast)
|
||||
return failure();
|
||||
auto srcVecType = broadcast.source().getType().dyn_cast<VectorType>();
|
||||
unsigned srcRrank = srcVecType ? srcVecType.getRank() : 0;
|
||||
auto dstVecType = op.getType().cast<VectorType>();
|
||||
unsigned dstRank = dstVecType.getRank();
|
||||
unsigned rankDiff = dstRank - srcRrank;
|
||||
// Check if the most inner dimensions of the source of the broacast are the
|
||||
// same as the destination of the extract. If this is the case we can just
|
||||
// use a broadcast as the original dimensions are untouched.
|
||||
bool lowerDimMatch = true;
|
||||
for (unsigned i = 0; i < srcRrank; i++) {
|
||||
if (srcVecType.getDimSize(i) != dstVecType.getDimSize(i + rankDiff)) {
|
||||
lowerDimMatch = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
Value source = broadcast.source();
|
||||
if (!lowerDimMatch) {
|
||||
// The inner dimensions don't match, it means we need to extract from the
|
||||
// source of the orignal broadcast and then broadcast the extracted value.
|
||||
source = rewriter.create<ExtractStridedSliceOp>(
|
||||
op->getLoc(), source,
|
||||
getI64SubArray(op.offsets(), /* dropFront=*/rankDiff),
|
||||
getI64SubArray(op.sizes(), /* dropFront=*/rankDiff),
|
||||
getI64SubArray(op.strides(), /* dropFront=*/rankDiff));
|
||||
}
|
||||
rewriter.replaceOpWithNewOp<BroadcastOp>(op, op.getType(), source);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
} // end anonymous namespace
|
||||
|
||||
void ExtractStridedSliceOp::getCanonicalizationPatterns(
|
||||
OwningRewritePatternList &results, MLIRContext *context) {
|
||||
// Pattern to rewrite a ExtractStridedSliceOp(ConstantMaskOp) ->
|
||||
// ConstantMaskOp and ExtractStridedSliceOp(ConstantOp) -> ConstantOp.
|
||||
results.insert<StridedSliceConstantMaskFolder, StridedSliceConstantFolder>(
|
||||
context);
|
||||
results.insert<StridedSliceConstantMaskFolder, StridedSliceConstantFolder,
|
||||
StridedSliceBroadcast>(context);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -2652,10 +2739,12 @@ OpFoldResult ShapeCastOp::fold(ArrayRef<Attribute> operands) {
|
|||
return source();
|
||||
|
||||
// Canceling shape casts.
|
||||
if (auto otherOp = source().getDefiningOp<ShapeCastOp>())
|
||||
if (auto otherOp = source().getDefiningOp<ShapeCastOp>()) {
|
||||
if (result().getType() == otherOp.source().getType())
|
||||
return otherOp.source();
|
||||
|
||||
setOperand(otherOp.source());
|
||||
return getResult();
|
||||
}
|
||||
return {};
|
||||
}
|
||||
|
||||
|
|
|
@ -613,4 +613,51 @@ func @extract_strided_constant() -> (vector<12x2xf32>, vector<2x13x3xi32>) {
|
|||
return %0, %1 : vector<12x2xf32>, vector<2x13x3xi32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: extract_strided_broadcast
|
||||
// CHECK: %[[B:.*]] = vector.broadcast %{{.*}} : vector<4xf16> to vector<2x4xf16>
|
||||
// CHECK-NEXT: return %[[B]] : vector<2x4xf16>
|
||||
func @extract_strided_broadcast(%arg0: vector<4xf16>) -> vector<2x4xf16> {
|
||||
%0 = vector.broadcast %arg0 : vector<4xf16> to vector<16x4xf16>
|
||||
%1 = vector.extract_strided_slice %0
|
||||
{offsets = [0, 0], sizes = [2, 4], strides = [1, 1]} :
|
||||
vector<16x4xf16> to vector<2x4xf16>
|
||||
return %1 : vector<2x4xf16>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: extract_strided_broadcast2
|
||||
// CHECK: %[[E:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0], sizes = [2], strides = [1]} : vector<4xf16> to vector<2xf16>
|
||||
// CHECK-NEXT: %[[B:.*]] = vector.broadcast %[[E]] : vector<2xf16> to vector<2x2xf16>
|
||||
// CHECK-NEXT: return %[[B]] : vector<2x2xf16>
|
||||
func @extract_strided_broadcast2(%arg0: vector<4xf16>) -> vector<2x2xf16> {
|
||||
%0 = vector.broadcast %arg0 : vector<4xf16> to vector<16x4xf16>
|
||||
%1 = vector.extract_strided_slice %0
|
||||
{offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} :
|
||||
vector<16x4xf16> to vector<2x2xf16>
|
||||
return %1 : vector<2x2xf16>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: consecutive_shape_cast
|
||||
// CHECK: %[[C:.*]] = vector.shape_cast %{{.*}} : vector<16xf16> to vector<4x4xf16>
|
||||
// CHECK-NEXT: return %[[C]] : vector<4x4xf16>
|
||||
func @consecutive_shape_cast(%arg0: vector<16xf16>) -> vector<4x4xf16> {
|
||||
%0 = vector.shape_cast %arg0 : vector<16xf16> to vector<2x8xf16>
|
||||
%1 = vector.shape_cast %0 : vector<2x8xf16> to vector<4x4xf16>
|
||||
return %1 : vector<4x4xf16>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: broadcast_to_shapecast
|
||||
// CHECK: %[[C:.*]] = vector.shape_cast %{{.*}} : vector<4x4xf16> to vector<1x4x4xf16>
|
||||
// CHECK-NEXT: return %[[C]] : vector<1x4x4xf16>
|
||||
func @broadcast_to_shapecast(%arg0: vector<4x4xf16>) -> vector<1x4x4xf16> {
|
||||
%0 = vector.broadcast %arg0 : vector<4x4xf16> to vector<1x4x4xf16>
|
||||
return %0 : vector<1x4x4xf16>
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue