forked from OSchip/llvm-project
[mlir][linalg] Fold linalg.pad_tensor if src type == result type
Fold PadTensorOp to source if source type and result type have static shape and are equal. Differential Revision: https://reviews.llvm.org/D103778
This commit is contained in:
parent
f5dc511c53
commit
b6ab4f1a8b
|
@ -296,6 +296,7 @@ def Linalg_PadTensorOp : Linalg_Op<"pad_tensor",
|
||||||
];
|
];
|
||||||
|
|
||||||
let hasCanonicalizer = 1;
|
let hasCanonicalizer = 1;
|
||||||
|
let hasFolder = 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
def Linalg_RangeOp :
|
def Linalg_RangeOp :
|
||||||
|
|
|
@ -1164,6 +1164,12 @@ Value PadTensorOp::getConstantPaddingValue() {
|
||||||
return padValue;
|
return padValue;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
OpFoldResult PadTensorOp::fold(ArrayRef<Attribute>) {
|
||||||
|
if (getResultType().hasStaticShape() && getResultType() == getSourceType())
|
||||||
|
return source();
|
||||||
|
return {};
|
||||||
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// ReshapeOp
|
// ReshapeOp
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
|
@ -893,6 +893,22 @@ func @dead_linalg_tensor(%arg0 : tensor<7x7xi32>, %arg1 : tensor<7x7xf32>,
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: func @pad_tensor_same_static_shape(
|
||||||
|
// CHECK-SAME: %[[ARG0:.*]]: tensor<5x6xf32>
|
||||||
|
// CHECK-NOT: linalg.pad_tensor
|
||||||
|
// CHECK: return %[[ARG0]]
|
||||||
|
func @pad_tensor_same_static_shape(%arg0: tensor<5x6xf32>, %a: index)
|
||||||
|
-> tensor<5x6xf32> {
|
||||||
|
%cst = constant 0.000000e+00 : f32
|
||||||
|
%0 = linalg.pad_tensor %arg0 low[%a, 0] high[0, %a] {
|
||||||
|
^bb0(%arg1: index, %arg2: index):
|
||||||
|
linalg.yield %cst : f32
|
||||||
|
} : tensor<5x6xf32> to tensor<5x6xf32>
|
||||||
|
return %0 : tensor<5x6xf32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
func @dim_reshape_expansion(%arg0 : tensor<6x5x?xf32>) -> (index, index, index)
|
func @dim_reshape_expansion(%arg0 : tensor<6x5x?xf32>) -> (index, index, index)
|
||||||
{
|
{
|
||||||
%c1 = constant 1 : index
|
%c1 = constant 1 : index
|
||||||
|
|
Loading…
Reference in New Issue