[mlir][Linalg] Add GenericOp self-copy on buffers folding

Reviewed By: pifon2a

Differential Revision: https://reviews.llvm.org/D118116
This commit is contained in:
Nicolas Vasilache 2022-01-26 05:19:53 -05:00
parent ed4efee2a3
commit 9b6c2ea302
3 changed files with 30 additions and 3 deletions

View File

@ -843,8 +843,6 @@ struct EraseIdentityGenericOp : public OpRewritePattern<GenericOp> {
LogicalResult matchAndRewrite(GenericOp genericOp,
PatternRewriter &rewriter) const override {
if (!genericOp.hasTensorSemantics())
return failure();
// Check all indexing maps are identity.
if (llvm::any_of(genericOp.getIndexingMaps(),
[](AffineMap map) { return !map.isIdentity(); }))
@ -859,6 +857,17 @@ struct EraseIdentityGenericOp : public OpRewritePattern<GenericOp> {
if (!yieldOp)
return failure();
// In the buffer case, we need to check exact buffer equality.
if (genericOp.hasBufferSemantics()) {
if (genericOp.getNumInputs() == 1 && genericOp.getNumOutputs() == 1 &&
genericOp.getInputOperand(0)->get() ==
genericOp.getOutputOperand(0)->get()) {
rewriter.eraseOp(genericOp);
return success();
}
return failure();
}
// Get the argument number of the returned values. That is the operand
// number to use for replacing uses of this operation.
SmallVector<Value> returnedArgs;
@ -876,6 +885,7 @@ struct EraseIdentityGenericOp : public OpRewritePattern<GenericOp> {
resultType, returnedArg);
returnedArgs.push_back(returnedArg);
}
if (returnedArgs.size() != genericOp->getNumResults())
return failure();
rewriter.replaceOp(genericOp, returnedArgs);

View File

@ -583,3 +583,19 @@ func @dim_of_tiled_loop_result_no_canonicalize(%arg0: tensor<?x?xf32>, %arg1: te
%r2 = tensor.dim %r, %c0 : tensor<?x?xf32>
return %r2 : index
}
// -----
// CHECK: func @fold_self_copy
func @fold_self_copy(%0 : memref<4x16xf32>) {
// CHECK-NEXT: return
linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
affine_map<(d0, d1) -> (d0, d1)>],
iterator_types = ["parallel", "parallel"]}
ins(%0 : memref<4x16xf32>)
outs(%0 : memref<4x16xf32>) {
^bb0(%arg4: f32, %arg5: f32):
linalg.yield %arg4 : f32
}
return
}

View File

@ -25,7 +25,8 @@ func @inlined_fn(%arg0: memref<?xf32>) {
ins(%arg0 : memref<?xf32>)
outs(%arg0 : memref<?xf32>) {
^bb(%0 : f32, %1 : f32) :
linalg.yield %0 : f32
%2 = arith.addf %0, %0: f32
linalg.yield %2 : f32
}
return
}