forked from OSchip/llvm-project
[mlir][vector] Add canonicalization patterns for ExtractStride/ShapeCast + Splat constant
Differential Revision: https://reviews.llvm.org/D90567
This commit is contained in:
parent
e969ab4320
commit
36480657d8
|
@ -1649,6 +1649,7 @@ def Vector_ShapeCastOp :
|
|||
}];
|
||||
let assemblyFormat = "$source attr-dict `:` type($source) `to` type($result)";
|
||||
let hasFolder = 1;
|
||||
let hasCanonicalizer = 1;
|
||||
}
|
||||
|
||||
def Vector_BitCastOp :
|
||||
|
|
|
@ -1770,13 +1770,39 @@ public:
|
|||
}
|
||||
};
|
||||
|
||||
// Pattern to rewrite a ExtractStridedSliceOp(splat ConstantOp) -> ConstantOp.
|
||||
class StridedSliceConstantFolder final
|
||||
: public OpRewritePattern<ExtractStridedSliceOp> {
|
||||
public:
|
||||
using OpRewritePattern<ExtractStridedSliceOp>::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(ExtractStridedSliceOp extractStridedSliceOp,
|
||||
PatternRewriter &rewriter) const override {
|
||||
// Return if 'extractStridedSliceOp' operand is not defined by a
|
||||
// ConstantOp.
|
||||
auto constantOp =
|
||||
extractStridedSliceOp.vector().getDefiningOp<ConstantOp>();
|
||||
if (!constantOp)
|
||||
return failure();
|
||||
auto dense = constantOp.value().dyn_cast<SplatElementsAttr>();
|
||||
if (!dense)
|
||||
return failure();
|
||||
auto newAttr = DenseElementsAttr::get(
|
||||
extractStridedSliceOp.getType().cast<VectorType>(),
|
||||
dense.getSplatValue());
|
||||
rewriter.replaceOpWithNewOp<ConstantOp>(extractStridedSliceOp, newAttr);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
} // end anonymous namespace
|
||||
|
||||
void ExtractStridedSliceOp::getCanonicalizationPatterns(
|
||||
OwningRewritePatternList &results, MLIRContext *context) {
|
||||
// Pattern to rewrite a ExtractStridedSliceOp(ConstantMaskOp) ->
|
||||
// ConstantMaskOp.
|
||||
results.insert<StridedSliceConstantMaskFolder>(context);
|
||||
// ConstantMaskOp and ExtractStridedSliceOp(ConstantOp) -> ConstantOp.
|
||||
results.insert<StridedSliceConstantMaskFolder, StridedSliceConstantFolder>(
|
||||
context);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -2560,6 +2586,36 @@ OpFoldResult ShapeCastOp::fold(ArrayRef<Attribute> operands) {
|
|||
return {};
|
||||
}
|
||||
|
||||
namespace {
|
||||
// Pattern to rewrite a ShapeCast(splat ConstantOp) -> ConstantOp.
|
||||
class ShapeCastConstantFolder final : public OpRewritePattern<ShapeCastOp> {
|
||||
public:
|
||||
using OpRewritePattern<ShapeCastOp>::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(ShapeCastOp shapeCastOp,
|
||||
PatternRewriter &rewriter) const override {
|
||||
auto constantOp = shapeCastOp.source().getDefiningOp<ConstantOp>();
|
||||
if (!constantOp)
|
||||
return failure();
|
||||
// Only handle splat for now.
|
||||
auto dense = constantOp.value().dyn_cast<SplatElementsAttr>();
|
||||
if (!dense)
|
||||
return failure();
|
||||
auto newAttr = DenseElementsAttr::get(
|
||||
shapeCastOp.getType().cast<VectorType>(), dense.getSplatValue());
|
||||
rewriter.replaceOpWithNewOp<ConstantOp>(shapeCastOp, newAttr);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
void ShapeCastOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
|
||||
MLIRContext *context) {
|
||||
// Pattern to rewrite a ShapeCastOp(ConstantOp) -> ConstantOp.
|
||||
results.insert<ShapeCastConstantFolder>(context);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// VectorBitCastOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -580,3 +580,37 @@ func @broadcast_folding2() -> vector<4x16xi32> {
|
|||
%2 = vector.broadcast %1 : vector<16xi32> to vector<4x16xi32>
|
||||
return %2 : vector<4x16xi32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: shape_cast_constant
|
||||
// CHECK: %[[CST0:.*]] = constant dense<2.000000e+00> : vector<20x2xf32>
|
||||
// CHECK: %[[CST1:.*]] = constant dense<1> : vector<3x4x2xi32>
|
||||
// CHECK: return %[[CST0]], %[[CST1]] : vector<20x2xf32>, vector<3x4x2xi32>
|
||||
func @shape_cast_constant() -> (vector<20x2xf32>, vector<3x4x2xi32>) {
|
||||
%cst = constant dense<2.000000e+00> : vector<5x4x2xf32>
|
||||
%cst_1 = constant dense<1> : vector<12x2xi32>
|
||||
%0 = vector.shape_cast %cst : vector<5x4x2xf32> to vector<20x2xf32>
|
||||
%1 = vector.shape_cast %cst_1 : vector<12x2xi32> to vector<3x4x2xi32>
|
||||
return %0, %1 : vector<20x2xf32>, vector<3x4x2xi32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: extract_strided_constant
|
||||
// CHECK: %[[CST0:.*]] = constant dense<2.000000e+00> : vector<12x2xf32>
|
||||
// CHECK: %[[CST1:.*]] = constant dense<1> : vector<2x13x3xi32>
|
||||
// CHECK: return %[[CST0]], %[[CST1]] : vector<12x2xf32>, vector<2x13x3xi32>
|
||||
func @extract_strided_constant() -> (vector<12x2xf32>, vector<2x13x3xi32>) {
|
||||
%cst = constant dense<2.000000e+00> : vector<29x7xf32>
|
||||
%cst_1 = constant dense<1> : vector<4x37x9xi32>
|
||||
%0 = vector.extract_strided_slice %cst
|
||||
{offsets = [2, 3], sizes = [12, 2], strides = [1, 1]}
|
||||
: vector<29x7xf32> to vector<12x2xf32>
|
||||
%1 = vector.extract_strided_slice %cst_1
|
||||
{offsets = [1, 2, 5], sizes = [2, 13, 3], strides = [1, 1, 1]}
|
||||
: vector<4x37x9xi32> to vector<2x13x3xi32>
|
||||
return %0, %1 : vector<12x2xf32>, vector<2x13x3xi32>
|
||||
}
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue