forked from OSchip/llvm-project
[mlir][vector] Add canonicalization patterns for ExtractStride/ShapeCast + Splat constant
Differential Revision: https://reviews.llvm.org/D90567
This commit is contained in:
parent
e969ab4320
commit
36480657d8
|
@ -1649,6 +1649,7 @@ def Vector_ShapeCastOp :
|
||||||
}];
|
}];
|
||||||
let assemblyFormat = "$source attr-dict `:` type($source) `to` type($result)";
|
let assemblyFormat = "$source attr-dict `:` type($source) `to` type($result)";
|
||||||
let hasFolder = 1;
|
let hasFolder = 1;
|
||||||
|
let hasCanonicalizer = 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
def Vector_BitCastOp :
|
def Vector_BitCastOp :
|
||||||
|
|
|
@ -1770,13 +1770,39 @@ public:
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// Pattern to rewrite a ExtractStridedSliceOp(splat ConstantOp) -> ConstantOp.
|
||||||
|
class StridedSliceConstantFolder final
|
||||||
|
: public OpRewritePattern<ExtractStridedSliceOp> {
|
||||||
|
public:
|
||||||
|
using OpRewritePattern<ExtractStridedSliceOp>::OpRewritePattern;
|
||||||
|
|
||||||
|
LogicalResult matchAndRewrite(ExtractStridedSliceOp extractStridedSliceOp,
|
||||||
|
PatternRewriter &rewriter) const override {
|
||||||
|
// Return if 'extractStridedSliceOp' operand is not defined by a
|
||||||
|
// ConstantOp.
|
||||||
|
auto constantOp =
|
||||||
|
extractStridedSliceOp.vector().getDefiningOp<ConstantOp>();
|
||||||
|
if (!constantOp)
|
||||||
|
return failure();
|
||||||
|
auto dense = constantOp.value().dyn_cast<SplatElementsAttr>();
|
||||||
|
if (!dense)
|
||||||
|
return failure();
|
||||||
|
auto newAttr = DenseElementsAttr::get(
|
||||||
|
extractStridedSliceOp.getType().cast<VectorType>(),
|
||||||
|
dense.getSplatValue());
|
||||||
|
rewriter.replaceOpWithNewOp<ConstantOp>(extractStridedSliceOp, newAttr);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
} // end anonymous namespace
|
} // end anonymous namespace
|
||||||
|
|
||||||
void ExtractStridedSliceOp::getCanonicalizationPatterns(
|
void ExtractStridedSliceOp::getCanonicalizationPatterns(
|
||||||
OwningRewritePatternList &results, MLIRContext *context) {
|
OwningRewritePatternList &results, MLIRContext *context) {
|
||||||
// Pattern to rewrite a ExtractStridedSliceOp(ConstantMaskOp) ->
|
// Pattern to rewrite a ExtractStridedSliceOp(ConstantMaskOp) ->
|
||||||
// ConstantMaskOp.
|
// ConstantMaskOp and ExtractStridedSliceOp(ConstantOp) -> ConstantOp.
|
||||||
results.insert<StridedSliceConstantMaskFolder>(context);
|
results.insert<StridedSliceConstantMaskFolder, StridedSliceConstantFolder>(
|
||||||
|
context);
|
||||||
}
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
@ -2560,6 +2586,36 @@ OpFoldResult ShapeCastOp::fold(ArrayRef<Attribute> operands) {
|
||||||
return {};
|
return {};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
// Pattern to rewrite a ShapeCast(splat ConstantOp) -> ConstantOp.
|
||||||
|
class ShapeCastConstantFolder final : public OpRewritePattern<ShapeCastOp> {
|
||||||
|
public:
|
||||||
|
using OpRewritePattern<ShapeCastOp>::OpRewritePattern;
|
||||||
|
|
||||||
|
LogicalResult matchAndRewrite(ShapeCastOp shapeCastOp,
|
||||||
|
PatternRewriter &rewriter) const override {
|
||||||
|
auto constantOp = shapeCastOp.source().getDefiningOp<ConstantOp>();
|
||||||
|
if (!constantOp)
|
||||||
|
return failure();
|
||||||
|
// Only handle splat for now.
|
||||||
|
auto dense = constantOp.value().dyn_cast<SplatElementsAttr>();
|
||||||
|
if (!dense)
|
||||||
|
return failure();
|
||||||
|
auto newAttr = DenseElementsAttr::get(
|
||||||
|
shapeCastOp.getType().cast<VectorType>(), dense.getSplatValue());
|
||||||
|
rewriter.replaceOpWithNewOp<ConstantOp>(shapeCastOp, newAttr);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
void ShapeCastOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
|
||||||
|
MLIRContext *context) {
|
||||||
|
// Pattern to rewrite a ShapeCastOp(ConstantOp) -> ConstantOp.
|
||||||
|
results.insert<ShapeCastConstantFolder>(context);
|
||||||
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// VectorBitCastOp
|
// VectorBitCastOp
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
|
@ -580,3 +580,37 @@ func @broadcast_folding2() -> vector<4x16xi32> {
|
||||||
%2 = vector.broadcast %1 : vector<16xi32> to vector<4x16xi32>
|
%2 = vector.broadcast %1 : vector<16xi32> to vector<4x16xi32>
|
||||||
return %2 : vector<4x16xi32>
|
return %2 : vector<4x16xi32>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: shape_cast_constant
|
||||||
|
// CHECK: %[[CST0:.*]] = constant dense<2.000000e+00> : vector<20x2xf32>
|
||||||
|
// CHECK: %[[CST1:.*]] = constant dense<1> : vector<3x4x2xi32>
|
||||||
|
// CHECK: return %[[CST0]], %[[CST1]] : vector<20x2xf32>, vector<3x4x2xi32>
|
||||||
|
func @shape_cast_constant() -> (vector<20x2xf32>, vector<3x4x2xi32>) {
|
||||||
|
%cst = constant dense<2.000000e+00> : vector<5x4x2xf32>
|
||||||
|
%cst_1 = constant dense<1> : vector<12x2xi32>
|
||||||
|
%0 = vector.shape_cast %cst : vector<5x4x2xf32> to vector<20x2xf32>
|
||||||
|
%1 = vector.shape_cast %cst_1 : vector<12x2xi32> to vector<3x4x2xi32>
|
||||||
|
return %0, %1 : vector<20x2xf32>, vector<3x4x2xi32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: extract_strided_constant
|
||||||
|
// CHECK: %[[CST0:.*]] = constant dense<2.000000e+00> : vector<12x2xf32>
|
||||||
|
// CHECK: %[[CST1:.*]] = constant dense<1> : vector<2x13x3xi32>
|
||||||
|
// CHECK: return %[[CST0]], %[[CST1]] : vector<12x2xf32>, vector<2x13x3xi32>
|
||||||
|
func @extract_strided_constant() -> (vector<12x2xf32>, vector<2x13x3xi32>) {
|
||||||
|
%cst = constant dense<2.000000e+00> : vector<29x7xf32>
|
||||||
|
%cst_1 = constant dense<1> : vector<4x37x9xi32>
|
||||||
|
%0 = vector.extract_strided_slice %cst
|
||||||
|
{offsets = [2, 3], sizes = [12, 2], strides = [1, 1]}
|
||||||
|
: vector<29x7xf32> to vector<12x2xf32>
|
||||||
|
%1 = vector.extract_strided_slice %cst_1
|
||||||
|
{offsets = [1, 2, 5], sizes = [2, 13, 3], strides = [1, 1, 1]}
|
||||||
|
: vector<4x37x9xi32> to vector<2x13x3xi32>
|
||||||
|
return %0, %1 : vector<12x2xf32>, vector<2x13x3xi32>
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue