[mlir][vector] Add canonicalization patterns for ExtractStride/ShapeCast + Splat constant

Differential Revision: https://reviews.llvm.org/D90567
This commit is contained in:
Thomas Raoux 2020-11-03 10:56:22 -08:00
parent e969ab4320
commit 36480657d8
3 changed files with 93 additions and 2 deletions

View File

@ -1649,6 +1649,7 @@ def Vector_ShapeCastOp :
}]; }];
let assemblyFormat = "$source attr-dict `:` type($source) `to` type($result)"; let assemblyFormat = "$source attr-dict `:` type($source) `to` type($result)";
let hasFolder = 1; let hasFolder = 1;
let hasCanonicalizer = 1;
} }
def Vector_BitCastOp : def Vector_BitCastOp :

View File

@ -1770,13 +1770,39 @@ public:
} }
}; };
// Pattern to rewrite a ExtractStridedSliceOp(splat ConstantOp) -> ConstantOp.
class StridedSliceConstantFolder final
: public OpRewritePattern<ExtractStridedSliceOp> {
public:
using OpRewritePattern<ExtractStridedSliceOp>::OpRewritePattern;
LogicalResult matchAndRewrite(ExtractStridedSliceOp extractStridedSliceOp,
PatternRewriter &rewriter) const override {
// Return if 'extractStridedSliceOp' operand is not defined by a
// ConstantOp.
auto constantOp =
extractStridedSliceOp.vector().getDefiningOp<ConstantOp>();
if (!constantOp)
return failure();
auto dense = constantOp.value().dyn_cast<SplatElementsAttr>();
if (!dense)
return failure();
auto newAttr = DenseElementsAttr::get(
extractStridedSliceOp.getType().cast<VectorType>(),
dense.getSplatValue());
rewriter.replaceOpWithNewOp<ConstantOp>(extractStridedSliceOp, newAttr);
return success();
}
};
} // end anonymous namespace } // end anonymous namespace
void ExtractStridedSliceOp::getCanonicalizationPatterns( void ExtractStridedSliceOp::getCanonicalizationPatterns(
OwningRewritePatternList &results, MLIRContext *context) { OwningRewritePatternList &results, MLIRContext *context) {
// Pattern to rewrite a ExtractStridedSliceOp(ConstantMaskOp) -> // Pattern to rewrite a ExtractStridedSliceOp(ConstantMaskOp) ->
// ConstantMaskOp. // ConstantMaskOp and ExtractStridedSliceOp(ConstantOp) -> ConstantOp.
results.insert<StridedSliceConstantMaskFolder>(context); results.insert<StridedSliceConstantMaskFolder, StridedSliceConstantFolder>(
context);
} }
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -2560,6 +2586,36 @@ OpFoldResult ShapeCastOp::fold(ArrayRef<Attribute> operands) {
return {}; return {};
} }
namespace {
// Pattern to rewrite a ShapeCast(splat ConstantOp) -> ConstantOp.
class ShapeCastConstantFolder final : public OpRewritePattern<ShapeCastOp> {
public:
using OpRewritePattern<ShapeCastOp>::OpRewritePattern;
LogicalResult matchAndRewrite(ShapeCastOp shapeCastOp,
PatternRewriter &rewriter) const override {
auto constantOp = shapeCastOp.source().getDefiningOp<ConstantOp>();
if (!constantOp)
return failure();
// Only handle splat for now.
auto dense = constantOp.value().dyn_cast<SplatElementsAttr>();
if (!dense)
return failure();
auto newAttr = DenseElementsAttr::get(
shapeCastOp.getType().cast<VectorType>(), dense.getSplatValue());
rewriter.replaceOpWithNewOp<ConstantOp>(shapeCastOp, newAttr);
return success();
}
};
} // namespace
void ShapeCastOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
MLIRContext *context) {
// Pattern to rewrite a ShapeCastOp(ConstantOp) -> ConstantOp.
results.insert<ShapeCastConstantFolder>(context);
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// VectorBitCastOp // VectorBitCastOp
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//

View File

@ -580,3 +580,37 @@ func @broadcast_folding2() -> vector<4x16xi32> {
%2 = vector.broadcast %1 : vector<16xi32> to vector<4x16xi32> %2 = vector.broadcast %1 : vector<16xi32> to vector<4x16xi32>
return %2 : vector<4x16xi32> return %2 : vector<4x16xi32>
} }
// -----
// CHECK-LABEL: shape_cast_constant
// CHECK: %[[CST0:.*]] = constant dense<2.000000e+00> : vector<20x2xf32>
// CHECK: %[[CST1:.*]] = constant dense<1> : vector<3x4x2xi32>
// CHECK: return %[[CST0]], %[[CST1]] : vector<20x2xf32>, vector<3x4x2xi32>
func @shape_cast_constant() -> (vector<20x2xf32>, vector<3x4x2xi32>) {
%cst = constant dense<2.000000e+00> : vector<5x4x2xf32>
%cst_1 = constant dense<1> : vector<12x2xi32>
%0 = vector.shape_cast %cst : vector<5x4x2xf32> to vector<20x2xf32>
%1 = vector.shape_cast %cst_1 : vector<12x2xi32> to vector<3x4x2xi32>
return %0, %1 : vector<20x2xf32>, vector<3x4x2xi32>
}
// -----
// CHECK-LABEL: extract_strided_constant
// CHECK: %[[CST0:.*]] = constant dense<2.000000e+00> : vector<12x2xf32>
// CHECK: %[[CST1:.*]] = constant dense<1> : vector<2x13x3xi32>
// CHECK: return %[[CST0]], %[[CST1]] : vector<12x2xf32>, vector<2x13x3xi32>
func @extract_strided_constant() -> (vector<12x2xf32>, vector<2x13x3xi32>) {
%cst = constant dense<2.000000e+00> : vector<29x7xf32>
%cst_1 = constant dense<1> : vector<4x37x9xi32>
%0 = vector.extract_strided_slice %cst
{offsets = [2, 3], sizes = [12, 2], strides = [1, 1]}
: vector<29x7xf32> to vector<12x2xf32>
%1 = vector.extract_strided_slice %cst_1
{offsets = [1, 2, 5], sizes = [2, 13, 3], strides = [1, 1, 1]}
: vector<4x37x9xi32> to vector<2x13x3xi32>
return %0, %1 : vector<12x2xf32>, vector<2x13x3xi32>
}