[mlir] Add canonicalization for tensor_cast + tensor_to_memref
This helps bufferization passes by removing tensor_cast operations.

Differential Revision: https://reviews.llvm.org/D96745
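Concretely, the new pattern rewrites a tensor_cast that feeds a tensor_to_memref into a tensor_to_memref of the original tensor followed by a memref_cast. A minimal before/after sketch, derived from the test case added below (value names are illustrative):

Before:
  %0 = tensor.cast %arg0 : tensor<4x6x16x32xi8> to tensor<?x?x16x32xi8>
  %1 = tensor_to_memref %0 : memref<?x?x16x32xi8>

After canonicalization:
  %m = tensor_to_memref %arg0 : memref<4x6x16x32xi8>
  %1 = memref_cast %m : memref<4x6x16x32xi8> to memref<?x?x16x32xi8>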
This commit is contained in:
parent
cb1a42359b
commit
807e5467f3
@@ -3078,6 +3078,7 @@ def TensorToMemrefOp : Std_Op<"tensor_to_memref",
   let assemblyFormat = "$tensor attr-dict `:` type($memref)";
 
   let hasFolder = 1;
+  let hasCanonicalizer = 1;
 }
 
 //===----------------------------------------------------------------------===//
@@ -3558,6 +3558,37 @@ OpFoldResult TensorToMemrefOp::fold(ArrayRef<Attribute>) {
   return {};
 }
 
+namespace {
+/// Replace tensor_cast + tensor_to_memref by tensor_to_memref + memref_cast.
+struct TensorCastToMemref : public OpRewritePattern<TensorToMemrefOp> {
+  using OpRewritePattern<TensorToMemrefOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(TensorToMemrefOp tensorToMemRef,
+                                PatternRewriter &rewriter) const final {
+    auto tensorCastOperand =
+        tensorToMemRef.getOperand().getDefiningOp<tensor::CastOp>();
+    if (!tensorCastOperand)
+      return failure();
+    auto srcTensorType =
+        tensorCastOperand.getOperand().getType().dyn_cast<RankedTensorType>();
+    if (!srcTensorType)
+      return failure();
+    auto memrefType = MemRefType::get(srcTensorType.getShape(),
+                                      srcTensorType.getElementType());
+    Value memref = rewriter.create<TensorToMemrefOp>(
+        tensorToMemRef.getLoc(), memrefType, tensorCastOperand.getOperand());
+    rewriter.replaceOpWithNewOp<MemRefCastOp>(tensorToMemRef,
+                                              tensorToMemRef.getType(), memref);
+    return success();
+  }
+};
+} // namespace
+
+void TensorToMemrefOp::getCanonicalizationPatterns(
+    OwningRewritePatternList &results, MLIRContext *context) {
+  results.insert<TensorCastToMemref>(context);
+}
+
 //===----------------------------------------------------------------------===//
 // TransposeOp
 //===----------------------------------------------------------------------===//
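Note that the pattern only applies when the tensor_cast source has a ranked tensor type (the dyn_cast<RankedTensorType> check above); casts from unranked tensors are left untouched. A minimal sketch of a case the pattern skips (hypothetical IR, using an assumed value %unranked not present in the patch):

  %0 = tensor.cast %unranked : tensor<*xf32> to tensor<?x?xf32>
  %1 = tensor_to_memref %0 : memref<?x?xf32>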
@@ -131,3 +131,15 @@ func @fold_dim_of_tensor.cast(%arg0 : tensor<4x?xf32>) -> (index, index) {
   %2 = dim %0, %c1 : tensor<?x?xf32>
   return %1, %2: index, index
 }
+
+// CHECK-LABEL: func @tensor_cast_to_memref
+// CHECK-SAME: %[[ARG0:.+]]: tensor<4x6x16x32xi8>
+// CHECK: %[[M:.+]] = tensor_to_memref %[[ARG0]] : memref<4x6x16x32xi8>
+// CHECK: %[[M1:.+]] = memref_cast %[[M]] : memref<4x6x16x32xi8> to memref<?x?x16x32xi8>
+// CHECK: return %[[M1]] : memref<?x?x16x32xi8>
+func @tensor_cast_to_memref(%arg0 : tensor<4x6x16x32xi8>) ->
+  memref<?x?x16x32xi8> {
+  %0 = tensor.cast %arg0 : tensor<4x6x16x32xi8> to tensor<?x?x16x32xi8>
+  %1 = tensor_to_memref %0 : memref<?x?x16x32xi8>
+  return %1 : memref<?x?x16x32xi8>
+}
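The new test is exercised through the canonicalizer like the rest of this file; a sketch of the lit RUN line such a test file typically carries (the exact flags in the actual file may differ):

  // RUN: mlir-opt -canonicalize %s | FileCheck %s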