[mlir] Add canonicalization for tensor_cast + tensor_to_memref

This helps bufferization passes by removing tensor_cast operations.

Differential Revision: https://reviews.llvm.org/D96745
This commit is contained in:
Thomas Raoux 2021-02-15 21:10:07 -08:00
parent cb1a42359b
commit 807e5467f3
3 changed files with 44 additions and 0 deletions

View File

@ -3078,6 +3078,7 @@ def TensorToMemrefOp : Std_Op<"tensor_to_memref",
let assemblyFormat = "$tensor attr-dict `:` type($memref)"; let assemblyFormat = "$tensor attr-dict `:` type($memref)";
let hasFolder = 1; let hasFolder = 1;
let hasCanonicalizer = 1;
} }
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//

View File

@ -3558,6 +3558,37 @@ OpFoldResult TensorToMemrefOp::fold(ArrayRef<Attribute>) {
return {}; return {};
} }
namespace {
/// Replace tensor_cast + tensor_to_memref by tensor_to_memref + memref_cast.
struct TensorCastToMemref : public OpRewritePattern<TensorToMemrefOp> {
using OpRewritePattern<TensorToMemrefOp>::OpRewritePattern;
LogicalResult matchAndRewrite(TensorToMemrefOp tensorToMemRef,
PatternRewriter &rewriter) const final {
auto tensorCastOperand =
tensorToMemRef.getOperand().getDefiningOp<tensor::CastOp>();
if (!tensorCastOperand)
return failure();
auto srcTensorType =
tensorCastOperand.getOperand().getType().dyn_cast<RankedTensorType>();
if (!srcTensorType)
return failure();
auto memrefType = MemRefType::get(srcTensorType.getShape(),
srcTensorType.getElementType());
Value memref = rewriter.create<TensorToMemrefOp>(
tensorToMemRef.getLoc(), memrefType, tensorCastOperand.getOperand());
rewriter.replaceOpWithNewOp<MemRefCastOp>(tensorToMemRef,
tensorToMemRef.getType(), memref);
return success();
}
};
} // namespace
void TensorToMemrefOp::getCanonicalizationPatterns(
OwningRewritePatternList &results, MLIRContext *context) {
results.insert<TensorCastToMemref>(context);
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// TransposeOp // TransposeOp
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//

View File

@ -131,3 +131,15 @@ func @fold_dim_of_tensor.cast(%arg0 : tensor<4x?xf32>) -> (index, index) {
%2 = dim %0, %c1 : tensor<?x?xf32> %2 = dim %0, %c1 : tensor<?x?xf32>
return %1, %2: index, index return %1, %2: index, index
} }
// CHECK-LABEL: func @tensor_cast_to_memref
// CHECK-SAME: %[[ARG0:.+]]: tensor<4x6x16x32xi8>
// CHECK: %[[M:.+]] = tensor_to_memref %[[ARG0]] : memref<4x6x16x32xi8>
// CHECK: %[[M1:.+]] = memref_cast %[[M]] : memref<4x6x16x32xi8> to memref<?x?x16x32xi8>
// CHECK: return %[[M1]] : memref<?x?x16x32xi8>
func @tensor_cast_to_memref(%arg0 : tensor<4x6x16x32xi8>) ->
memref<?x?x16x32xi8> {
%0 = tensor.cast %arg0 : tensor<4x6x16x32xi8> to tensor<?x?x16x32xi8>
%1 = tensor_to_memref %0 : memref<?x?x16x32xi8>
return %1 : memref<?x?x16x32xi8>
}