forked from OSchip/llvm-project
[mlir] Add canonicalization for tensor_cast + tensor_to_memref
This helps bufferization passes by removing tensor_cast operations. Differential Revision: https://reviews.llvm.org/D96745
This commit is contained in:
parent
cb1a42359b
commit
807e5467f3
|
@ -3078,6 +3078,7 @@ def TensorToMemrefOp : Std_Op<"tensor_to_memref",
|
||||||
let assemblyFormat = "$tensor attr-dict `:` type($memref)";
|
let assemblyFormat = "$tensor attr-dict `:` type($memref)";
|
||||||
|
|
||||||
let hasFolder = 1;
|
let hasFolder = 1;
|
||||||
|
let hasCanonicalizer = 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
|
@ -3558,6 +3558,37 @@ OpFoldResult TensorToMemrefOp::fold(ArrayRef<Attribute>) {
|
||||||
return {};
|
return {};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
/// Replace tensor_cast + tensor_to_memref by tensor_to_memref + memref_cast.
|
||||||
|
struct TensorCastToMemref : public OpRewritePattern<TensorToMemrefOp> {
|
||||||
|
using OpRewritePattern<TensorToMemrefOp>::OpRewritePattern;
|
||||||
|
|
||||||
|
LogicalResult matchAndRewrite(TensorToMemrefOp tensorToMemRef,
|
||||||
|
PatternRewriter &rewriter) const final {
|
||||||
|
auto tensorCastOperand =
|
||||||
|
tensorToMemRef.getOperand().getDefiningOp<tensor::CastOp>();
|
||||||
|
if (!tensorCastOperand)
|
||||||
|
return failure();
|
||||||
|
auto srcTensorType =
|
||||||
|
tensorCastOperand.getOperand().getType().dyn_cast<RankedTensorType>();
|
||||||
|
if (!srcTensorType)
|
||||||
|
return failure();
|
||||||
|
auto memrefType = MemRefType::get(srcTensorType.getShape(),
|
||||||
|
srcTensorType.getElementType());
|
||||||
|
Value memref = rewriter.create<TensorToMemrefOp>(
|
||||||
|
tensorToMemRef.getLoc(), memrefType, tensorCastOperand.getOperand());
|
||||||
|
rewriter.replaceOpWithNewOp<MemRefCastOp>(tensorToMemRef,
|
||||||
|
tensorToMemRef.getType(), memref);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
void TensorToMemrefOp::getCanonicalizationPatterns(
|
||||||
|
OwningRewritePatternList &results, MLIRContext *context) {
|
||||||
|
results.insert<TensorCastToMemref>(context);
|
||||||
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// TransposeOp
|
// TransposeOp
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
|
@ -131,3 +131,15 @@ func @fold_dim_of_tensor.cast(%arg0 : tensor<4x?xf32>) -> (index, index) {
|
||||||
%2 = dim %0, %c1 : tensor<?x?xf32>
|
%2 = dim %0, %c1 : tensor<?x?xf32>
|
||||||
return %1, %2: index, index
|
return %1, %2: index, index
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: func @tensor_cast_to_memref
|
||||||
|
// CHECK-SAME: %[[ARG0:.+]]: tensor<4x6x16x32xi8>
|
||||||
|
// CHECK: %[[M:.+]] = tensor_to_memref %[[ARG0]] : memref<4x6x16x32xi8>
|
||||||
|
// CHECK: %[[M1:.+]] = memref_cast %[[M]] : memref<4x6x16x32xi8> to memref<?x?x16x32xi8>
|
||||||
|
// CHECK: return %[[M1]] : memref<?x?x16x32xi8>
|
||||||
|
func @tensor_cast_to_memref(%arg0 : tensor<4x6x16x32xi8>) ->
|
||||||
|
memref<?x?x16x32xi8> {
|
||||||
|
%0 = tensor.cast %arg0 : tensor<4x6x16x32xi8> to tensor<?x?x16x32xi8>
|
||||||
|
%1 = tensor_to_memref %0 : memref<?x?x16x32xi8>
|
||||||
|
return %1 : memref<?x?x16x32xi8>
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue