[mlir][linalg] Fix result type in FoldSourceTensorCast

* Do not discard static result type information that cannot be inferred from
  lower/upper padding.
* Add optional argument to `PadTensorOp::inferResultType` for specifying known
  result dimensions.

Differential Revision: https://reviews.llvm.org/D110380

parent 03142c5f67
commit f3f25ffc04
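
In effect, a result dimension that cannot be computed as source size plus static low/high padding (because one of them is dynamic) no longer degrades to dynamic unconditionally: when the caller already knows the result dimension, that size is kept. A minimal standalone sketch of the rule (plain C++, not the MLIR implementation; inferPadShape and kDynamic are illustrative stand-ins, with -1 playing the role of ShapedType::kDynamicSize):

#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// -1 stands in for ShapedType::kDynamicSize in this sketch.
constexpr int64_t kDynamic = -1;

// Per-dimension rule: the size is source + low + high when all three are
// static; otherwise it is dynamic, unless the caller supplies a known
// result dimension.
std::vector<int64_t> inferPadShape(const std::vector<int64_t> &source,
                                   const std::vector<int64_t> &low,
                                   const std::vector<int64_t> &high,
                                   const std::vector<int64_t> &result = {}) {
  std::vector<int64_t> shape;
  for (std::size_t i = 0; i < source.size(); ++i) {
    if (source[i] == kDynamic || low[i] == kDynamic || high[i] == kDynamic)
      shape.push_back(result.empty() ? kDynamic : result[i]);
    else
      shape.push_back(source[i] + low[i] + high[i]);
  }
  return shape;
}

int main() {
  // tensor<8x?xf32> padded low [0, 0] / high [0, <dynamic>]: without known
  // result dimensions, dim 1 is not inferable and becomes dynamic (8x?)...
  assert((inferPadShape({8, kDynamic}, {0, 0}, {0, kDynamic}) ==
          std::vector<int64_t>{8, kDynamic}));
  // ...but with the known result shape 8x32, the static 32 is preserved.
  assert((inferPadShape({8, kDynamic}, {0, 0}, {0, kDynamic}, {8, 32}) ==
          std::vector<int64_t>{8, 32}));
  return 0;
}

The second assert mirrors the @pad_tensor_of_cast test added below, where a tensor<8x32xf32> pad result used to be re-inferred as tensor<8x?xf32> once the source cast was folded away.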
@@ -226,10 +226,14 @@ def Linalg_PadTensorOp : Linalg_Op<"pad_tensor",
     }
 
-    // Infer the shape of the result tensor given the type of the source tensor
-    // and paddings.
-    static RankedTensorType inferResultType(RankedTensorType sourceType,
-                                            ArrayRef<int64_t> staticLow,
-                                            ArrayRef<int64_t> staticHigh);
+    // Infer the shape of the result tensor given the type of the source tensor
+    // and paddings. Known result dimensions that cannot necessarily be inferred
+    // from low/high padding sizes can be optionally specified. Those will be
+    // considered when computing the result type.
+    static RankedTensorType inferResultType(
+        RankedTensorType sourceType,
+        ArrayRef<int64_t> staticLow,
+        ArrayRef<int64_t> staticHigh,
+        ArrayRef<int64_t> resultShape = {});
 
     // Return a PadTensorOp that pads `source` to `type` size where the static
     // sizes are assumed to be greater than the dynamic sizes. The op performs
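A sketch of how a caller might use the new overload (the standalone main, context setup, and concrete shapes are assumptions for illustration, not part of the patch; header paths are those of this revision):

#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"

using namespace mlir;

int main() {
  MLIRContext ctx;
  auto f32 = FloatType::getF32(&ctx);
  int64_t dyn = ShapedType::kDynamicSize;

  // Source tensor<8x?xf32>, padded low [0, 0] / high [0, <dynamic>].
  auto src = RankedTensorType::get({8, dyn}, f32);

  // Without known result dimensions, dim 1 cannot be inferred:
  // yields tensor<8x?xf32>.
  auto inferred = linalg::PadTensorOp::inferResultType(src, {0, 0}, {0, dyn});

  // Passing the known result shape keeps the static 32:
  // yields tensor<8x32xf32>.
  auto refined =
      linalg::PadTensorOp::inferResultType(src, {0, 0}, {0, dyn}, {8, 32});
  (void)inferred;
  (void)refined;
  return 0;
}

Defaulting resultShape to {} keeps existing call sites source-compatible; in this patch only FoldSourceTensorCast (below) opts in.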
@@ -1055,24 +1055,31 @@ static LogicalResult verify(PadTensorOp op) {
 
-RankedTensorType PadTensorOp::inferResultType(RankedTensorType sourceType,
-                                              ArrayRef<int64_t> staticLow,
-                                              ArrayRef<int64_t> staticHigh) {
+RankedTensorType PadTensorOp::inferResultType(RankedTensorType sourceType,
+                                              ArrayRef<int64_t> staticLow,
+                                              ArrayRef<int64_t> staticHigh,
+                                              ArrayRef<int64_t> resultShape) {
   unsigned rank = sourceType.getRank();
   assert(staticLow.size() == rank && "unexpected staticLow size mismatch");
   assert(staticHigh.size() == rank && "unexpected staticHigh size mismatch");
+  assert((resultShape.empty() || resultShape.size() == rank) &&
+         "unexpected resultShape size mismatch");
 
-  SmallVector<int64_t, 4> resultShape;
+  SmallVector<int64_t, 4> inferredShape;
   for (auto i : llvm::seq<unsigned>(0, rank)) {
     if (sourceType.isDynamicDim(i) ||
         staticLow[i] == ShapedType::kDynamicSize ||
         staticHigh[i] == ShapedType::kDynamicSize) {
-      resultShape.push_back(ShapedType::kDynamicSize);
+      inferredShape.push_back(resultShape.empty() ? ShapedType::kDynamicSize
+                                                  : resultShape[i]);
     } else {
       int64_t size = sourceType.getDimSize(i) + staticLow[i] + staticHigh[i];
-      resultShape.push_back(size);
+      assert((resultShape.empty() || size == resultShape[i] ||
+              resultShape[i] == ShapedType::kDynamicSize) &&
+             "mismatch between inferred shape and result shape");
+      inferredShape.push_back(size);
     }
   }
 
-  return RankedTensorType::get(resultShape, sourceType.getElementType());
+  return RankedTensorType::get(inferredShape, sourceType.getElementType());
 }
 
 void PadTensorOp::build(OpBuilder &b, OperationState &result, Value source,
@@ -1454,7 +1461,8 @@ struct FoldSourceTensorCast : public OpRewritePattern<PadTensorOp> {
     auto newResultType = PadTensorOp::inferResultType(
         castOp.source().getType().cast<RankedTensorType>(),
         extractFromI64ArrayAttr(padTensorOp.static_low()),
-        extractFromI64ArrayAttr(padTensorOp.static_high()));
+        extractFromI64ArrayAttr(padTensorOp.static_high()),
+        padTensorOp.getResultType().getShape());
 
     if (newResultType == padTensorOp.getResultType()) {
       rewriter.updateRootInPlace(padTensorOp, [&]() {
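With the known result shape forwarded, newResultType now matches padTensorOp.getResultType() even when some padding sizes are dynamic, so the pattern takes the in-place update path above; the @pad_tensor_of_cast test below (note the CHECK-NOT: tensor.cast) exercises exactly this case.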
@@ -629,7 +629,8 @@ func @pad_tensor_same_static_shape(%arg0: tensor<5x6xf32>, %a: index)
 }
 
 // -----
-// CHECK-LABEL: func @pad_tensor_after_cast_differnt_shape(
+
+// CHECK-LABEL: func @pad_tensor_after_cast_different_shape(
 // CHECK-SAME:    %[[INPUT:.*]]: tensor<?x64x?x?xf32>) -> tensor<?x?x?x?xf32> {
 // CHECK:         %[[CST:.*]] = constant 0.000000e+00 : f32
 // CHECK:         %[[PADDED:.*]] = linalg.pad_tensor %[[INPUT]]
@@ -641,7 +642,7 @@ func @pad_tensor_same_static_shape(%arg0: tensor<5x6xf32>, %a: index)
 // CHECK-SAME:    tensor<?x64x?x?xf32> to tensor<?x?x?x?xf32>
 // CHECK:         return %[[DYNAMIC]] : tensor<?x?x?x?xf32>
 // CHECK:       }
-func @pad_tensor_after_cast_differnt_shape(%arg0: tensor<?x64x?x?xf32>)
+func @pad_tensor_after_cast_different_shape(%arg0: tensor<?x64x?x?xf32>)
     -> tensor<?x?x?x?xf32> {
   %cst = constant 0.000000e+00 : f32
   %dynamic = tensor.cast %arg0 : tensor<?x64x?x?xf32> to tensor<?x?x?x?xf32>
@@ -653,6 +654,7 @@ func @pad_tensor_after_cast_differnt_shape(%arg0: tensor<?x64x?x?xf32>)
 }
 
 // -----
+
 // CHECK-LABEL: func @pad_tensor_after_cast_same_shape(
 // CHECK-SAME:    %[[INPUT:.*]]: tensor<?x64x?x?xf32>,
 // CHECK-SAME:    %[[PADDING:.*]]: index) -> tensor<?x?x?x?xf32> {
@@ -676,6 +678,24 @@ func @pad_tensor_after_cast_same_shape(%arg0: tensor<?x64x?x?xf32>, %padding : i
 }
 
 // -----
+
+// CHECK-LABEL: func @pad_tensor_of_cast(
+// CHECK-NOT:     tensor.cast
+// CHECK:         linalg.pad_tensor
+// CHECK:         tensor<8x?xf32> to tensor<8x32xf32>
+func @pad_tensor_of_cast(%t: tensor<8x?xf32>, %s: index) -> tensor<8x32xf32> {
+  %c0 = constant 0 : index
+  %cst = constant 0.000000e+00 : f32
+  %0 = tensor.cast %t : tensor<8x?xf32> to tensor<?x?xf32>
+  %1 = linalg.pad_tensor %0 low[%c0, %c0] high[%c0, %s] {
+  ^bb0(%arg9: index, %arg10: index):  // no predecessors
+    linalg.yield %cst : f32
+  } : tensor<?x?xf32> to tensor<8x32xf32>
+  return %1 : tensor<8x32xf32>
+}
+
+// -----
+
 func @propogate_casts(%arg0 : tensor<?x?xf32>, %arg1 : f32, %arg2 : index,
     %arg3 : index) -> tensor<?x?xf32> {
   %c0 = constant 0 : index