forked from OSchip/llvm-project
[mlir][linalg] Introduce a separate EraseIdentityCopyOp Pattern.
Split out an EraseIdentityCopyOp from the existing RemoveIdentityLinalgOps pattern. Introduce an additional check to ensure the pattern checks the permutation maps match. This is a preparation step to specialize RemoveIdentityLinalgOps to GenericOp only. Reviewed By: nicolasvasilache Differential Revision: https://reviews.llvm.org/D105622
This commit is contained in:
parent
4fd42e2e80
commit
ca0d244e99
|
@ -168,6 +168,7 @@ def CopyOp : LinalgStructured_Op<"copy", [CopyOpInterface]> {
|
|||
custom<CopyOpRegion>($region, ref(type($input)), ref(type($input)))
|
||||
}];
|
||||
|
||||
let hasCanonicalizer = 1;
|
||||
let hasFolder = 1;
|
||||
let skipDefaultBuilders = 1;
|
||||
}
|
||||
|
|
|
@ -426,6 +426,31 @@ void CopyOp::getEffects(
|
|||
SideEffects::DefaultResource::get());
|
||||
}
|
||||
|
||||
namespace {
|
||||
/// Remove copy operations that copy data inplace. Requirements are:
|
||||
/// 1) The input and output values are identical.
|
||||
/// 2) The input and output permutation maps are identical.
|
||||
struct EraseIdentityCopyOp : public OpRewritePattern<CopyOp> {
|
||||
using OpRewritePattern<CopyOp>::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(CopyOp copyOp,
|
||||
PatternRewriter &rewriter) const override {
|
||||
assert(copyOp.hasBufferSemantics());
|
||||
if (copyOp.input() == copyOp.output() &&
|
||||
copyOp.inputPermutation() == copyOp.outputPermutation()) {
|
||||
rewriter.eraseOp(copyOp);
|
||||
return success();
|
||||
}
|
||||
return failure();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
void CopyOp::getCanonicalizationPatterns(RewritePatternSet &results,
|
||||
MLIRContext *context) {
|
||||
results.add<EraseIdentityCopyOp>(context);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// FillOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -2615,15 +2640,6 @@ struct RemoveIdentityLinalgOps : public OpInterfaceRewritePattern<LinalgOp> {
|
|||
|
||||
LogicalResult matchAndRewrite(LinalgOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
if (auto copyOp = dyn_cast<CopyOp>(*op)) {
|
||||
assert(copyOp.hasBufferSemantics());
|
||||
if (copyOp.input() == copyOp.output() &&
|
||||
copyOp.inputPermutation() == copyOp.outputPermutation()) {
|
||||
rewriter.eraseOp(op);
|
||||
return success();
|
||||
}
|
||||
}
|
||||
|
||||
if (!isa<GenericOp>(op))
|
||||
return failure();
|
||||
if (!op.hasTensorSemantics())
|
||||
|
|
|
@ -661,6 +661,20 @@ func @self_copy(%arg0 : memref<2x3x?x4xf32>) {
|
|||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: @self_copy_with_permutation
|
||||
func @self_copy_with_permutation(%arg0 : memref<2x3x?x4xf32>) {
|
||||
|
||||
// CHECK: linalg.copy
|
||||
linalg.copy(%arg0, %arg0)
|
||||
{inputPermutation = affine_map<(i, j, k, l) -> (j, k, i, l)>,
|
||||
outputPermuation = affine_map<(i, j, k, l) -> (i, j, k, l)>} : memref<2x3x?x4xf32>, memref<2x3x?x4xf32>
|
||||
|
||||
// CHECK: return
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @fold_fill_reshape()
|
||||
func @fold_fill_reshape() -> tensor<6x4xf32> {
|
||||
%zero = constant 0.0 : f32
|
||||
|
|
Loading…
Reference in New Issue