[mlir][linalg] Fix result type in FoldSourceTensorCast

* Do not discard static result type information that cannot be inferred from lower/upper padding.
* Add optional argument to `PadTensorOp::inferResultType` for specifying known result dimensions.

Differential Revision: https://reviews.llvm.org/D110380
This commit is contained in:
Matthias Springer 2021-09-24 16:39:37 +09:00
parent 03142c5f67
commit f3f25ffc04
3 changed files with 43 additions and 11 deletions

View File

@ -226,10 +226,14 @@ def Linalg_PadTensorOp : Linalg_Op<"pad_tensor",
}
// Infer the shape of the result tensor given the type of the source tensor
// and paddings.
static RankedTensorType inferResultType(RankedTensorType sourceType,
// and paddings. Known result dimensions that cannot necessarily be inferred
// from low/high padding sizes can be optionally specified. Those will be
// considered when computing the result type.
static RankedTensorType inferResultType(
RankedTensorType sourceType,
ArrayRef<int64_t> staticLow,
ArrayRef<int64_t> staticHigh);
ArrayRef<int64_t> staticHigh,
ArrayRef<int64_t> resultShape = {});
// Return a PadTensorOp that pads `source` to `type` size where the static
// sizes are assumed to be greater than the dynamic sizes. The op performs

View File

@ -1055,24 +1055,31 @@ static LogicalResult verify(PadTensorOp op) {
RankedTensorType PadTensorOp::inferResultType(RankedTensorType sourceType,
ArrayRef<int64_t> staticLow,
ArrayRef<int64_t> staticHigh) {
ArrayRef<int64_t> staticHigh,
ArrayRef<int64_t> resultShape) {
unsigned rank = sourceType.getRank();
assert(staticLow.size() == rank && "unexpected staticLow size mismatch");
assert(staticHigh.size() == rank && "unexpected staticHigh size mismatch");
assert((resultShape.empty() || resultShape.size() == rank) &&
"unexpected resultShape size mismatch");
SmallVector<int64_t, 4> resultShape;
SmallVector<int64_t, 4> inferredShape;
for (auto i : llvm::seq<unsigned>(0, rank)) {
if (sourceType.isDynamicDim(i) ||
staticLow[i] == ShapedType::kDynamicSize ||
staticHigh[i] == ShapedType::kDynamicSize) {
resultShape.push_back(ShapedType::kDynamicSize);
inferredShape.push_back(resultShape.empty() ? ShapedType::kDynamicSize
: resultShape[i]);
} else {
int64_t size = sourceType.getDimSize(i) + staticLow[i] + staticHigh[i];
resultShape.push_back(size);
assert((resultShape.empty() || size == resultShape[i] ||
resultShape[i] == ShapedType::kDynamicSize) &&
"mismatch between inferred shape and result shape");
inferredShape.push_back(size);
}
}
return RankedTensorType::get(resultShape, sourceType.getElementType());
return RankedTensorType::get(inferredShape, sourceType.getElementType());
}
void PadTensorOp::build(OpBuilder &b, OperationState &result, Value source,
@ -1454,7 +1461,8 @@ struct FoldSourceTensorCast : public OpRewritePattern<PadTensorOp> {
auto newResultType = PadTensorOp::inferResultType(
castOp.source().getType().cast<RankedTensorType>(),
extractFromI64ArrayAttr(padTensorOp.static_low()),
extractFromI64ArrayAttr(padTensorOp.static_high()));
extractFromI64ArrayAttr(padTensorOp.static_high()),
padTensorOp.getResultType().getShape());
if (newResultType == padTensorOp.getResultType()) {
rewriter.updateRootInPlace(padTensorOp, [&]() {

View File

@ -629,7 +629,8 @@ func @pad_tensor_same_static_shape(%arg0: tensor<5x6xf32>, %a: index)
}
// -----
// CHECK-LABEL: func @pad_tensor_after_cast_differnt_shape(
// CHECK-LABEL: func @pad_tensor_after_cast_different_shape(
// CHECK-SAME: %[[INPUT:.*]]: tensor<?x64x?x?xf32>) -> tensor<?x?x?x?xf32> {
// CHECK: %[[CST:.*]] = constant 0.000000e+00 : f32
// CHECK: %[[PADDED:.*]] = linalg.pad_tensor %[[INPUT]]
@ -641,7 +642,7 @@ func @pad_tensor_same_static_shape(%arg0: tensor<5x6xf32>, %a: index)
// CHECK-SAME: tensor<?x64x?x?xf32> to tensor<?x?x?x?xf32>
// CHECK: return %[[DYNAMIC]] : tensor<?x?x?x?xf32>
// CHECK: }
func @pad_tensor_after_cast_differnt_shape(%arg0: tensor<?x64x?x?xf32>)
func @pad_tensor_after_cast_different_shape(%arg0: tensor<?x64x?x?xf32>)
-> tensor<?x?x?x?xf32> {
%cst = constant 0.000000e+00 : f32
%dynamic = tensor.cast %arg0 : tensor<?x64x?x?xf32> to tensor<?x?x?x?xf32>
@ -653,6 +654,7 @@ func @pad_tensor_after_cast_differnt_shape(%arg0: tensor<?x64x?x?xf32>)
}
// -----
// CHECK-LABEL: func @pad_tensor_after_cast_same_shape(
// CHECK-SAME: %[[INPUT:.*]]: tensor<?x64x?x?xf32>,
// CHECK-SAME: %[[PADDING:.*]]: index) -> tensor<?x?x?x?xf32> {
@ -676,6 +678,24 @@ func @pad_tensor_after_cast_same_shape(%arg0: tensor<?x64x?x?xf32>, %padding : i
}
// -----
// CHECK-LABEL: func @pad_tensor_of_cast(
// CHECK-NOT: tensor.cast
// CHECK: linalg.pad_tensor
// CHECK: tensor<8x?xf32> to tensor<8x32xf32>
func @pad_tensor_of_cast(%t: tensor<8x?xf32>, %s: index) -> tensor<8x32xf32> {
%c0 = constant 0 : index
%cst = constant 0.000000e+00 : f32
%0 = tensor.cast %t : tensor<8x?xf32> to tensor<?x?xf32>
%1 = linalg.pad_tensor %0 low[%c0, %c0] high[%c0, %s] {
^bb0(%arg9: index, %arg10: index): // no predecessors
linalg.yield %cst : f32
} : tensor<?x?xf32> to tensor<8x32xf32>
return %1 : tensor<8x32xf32>
}
// -----
func @propogate_casts(%arg0 : tensor<?x?xf32>, %arg1 : f32, %arg2 : index,
%arg3 : index) -> tensor<?x?xf32> {
%c0 = constant 0 : index