[mlir][linalg] Introduce a separate EraseIdentityCopyOp Pattern.

Split out an EraseIdentityCopyOp from the existing RemoveIdentityLinalgOps pattern. Introduce an additional check to ensure the pattern checks the permutation maps match. This is a preparation step to specialize RemoveIdentityLinalgOps to GenericOp only.

Reviewed By: nicolasvasilache

Differential Revision: https://reviews.llvm.org/D105622
This commit is contained in:
Tobias Gysi 2021-07-28 09:42:01 +00:00
parent 4fd42e2e80
commit ca0d244e99
3 changed files with 40 additions and 9 deletions

View File

@ -168,6 +168,7 @@ def CopyOp : LinalgStructured_Op<"copy", [CopyOpInterface]> {
custom<CopyOpRegion>($region, ref(type($input)), ref(type($input)))
}];
let hasCanonicalizer = 1;
let hasFolder = 1;
let skipDefaultBuilders = 1;
}

View File

@ -426,6 +426,31 @@ void CopyOp::getEffects(
SideEffects::DefaultResource::get());
}
namespace {
/// Remove copy operations that copy data inplace. Requirements are:
/// 1) The input and output values are identical.
/// 2) The input and output permutation maps are identical.
struct EraseIdentityCopyOp : public OpRewritePattern<CopyOp> {
using OpRewritePattern<CopyOp>::OpRewritePattern;
LogicalResult matchAndRewrite(CopyOp copyOp,
PatternRewriter &rewriter) const override {
assert(copyOp.hasBufferSemantics());
if (copyOp.input() == copyOp.output() &&
copyOp.inputPermutation() == copyOp.outputPermutation()) {
rewriter.eraseOp(copyOp);
return success();
}
return failure();
}
};
} // namespace
void CopyOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.add<EraseIdentityCopyOp>(context);
}
//===----------------------------------------------------------------------===//
// FillOp
//===----------------------------------------------------------------------===//
@ -2615,15 +2640,6 @@ struct RemoveIdentityLinalgOps : public OpInterfaceRewritePattern<LinalgOp> {
LogicalResult matchAndRewrite(LinalgOp op,
PatternRewriter &rewriter) const override {
if (auto copyOp = dyn_cast<CopyOp>(*op)) {
assert(copyOp.hasBufferSemantics());
if (copyOp.input() == copyOp.output() &&
copyOp.inputPermutation() == copyOp.outputPermutation()) {
rewriter.eraseOp(op);
return success();
}
}
if (!isa<GenericOp>(op))
return failure();
if (!op.hasTensorSemantics())

View File

@ -661,6 +661,20 @@ func @self_copy(%arg0 : memref<2x3x?x4xf32>) {
// -----
// CHECK-LABEL: @self_copy_with_permutation
func @self_copy_with_permutation(%arg0 : memref<2x3x?x4xf32>) {
// CHECK: linalg.copy
linalg.copy(%arg0, %arg0)
{inputPermutation = affine_map<(i, j, k, l) -> (j, k, i, l)>,
outputPermuation = affine_map<(i, j, k, l) -> (i, j, k, l)>} : memref<2x3x?x4xf32>, memref<2x3x?x4xf32>
// CHECK: return
return
}
// -----
// CHECK-LABEL: func @fold_fill_reshape()
func @fold_fill_reshape() -> tensor<6x4xf32> {
%zero = constant 0.0 : f32