[mlir][tensor] Fix insert_slice + tensor cast overflow

InsertSliceOp may have subprefix semantics where missing trailing dimensions
are automatically inferred directly from the operand shape.
This revision fixes an overflow that occurs in such cases when the impl is based on the op rank.

Differential Revision: https://reviews.llvm.org/D115549
This commit is contained in:
Nicolas Vasilache 2021-12-10 21:27:20 +00:00
parent 9a3df8fbc2
commit 5601821dae
2 changed files with 20 additions and 5 deletions

View File

@ -1417,11 +1417,11 @@ struct InsertSliceOpSourceCastInserter final
return failure();
SmallVector<int64_t> newSrcShape(srcType.getShape().begin(),
srcType.getShape().end());
for (int64_t i = 0; i < srcType.getRank(); ++i) {
if (Optional<int64_t> constInt =
getConstantIntValue(insertSliceOp.getMixedSizes()[i]))
newSrcShape[i] = *constInt;
}
// Offsets / sizes / strides can be a subprefix of the rank; take only the
// leading dimensions.
for (auto en : llvm::enumerate(insertSliceOp.getMixedSizes()))
if (Optional<int64_t> constInt = getConstantIntValue(en.value()))
newSrcShape[en.index()] = *constInt;
RankedTensorType newSrcType =
RankedTensorType::get(newSrcShape, srcType.getElementType());

View File

@ -536,6 +536,21 @@ func @insert_tensor_cast_on_insert_slice_src(
// -----
// CHECK-LABEL: func @insert_tensor_cast_on_insert_slice_src_prefix(
// CHECK-SAME: %[[arg0:.*]]: tensor<?x5x?xf32>, %[[arg1:.*]]: tensor<?x?x?xf32>
// CHECK: %[[cast:.*]] = tensor.cast %[[arg0]] : tensor<?x5x?xf32> to tensor<64x5x?xf32>
// CHECK: %[[r:.*]] = tensor.insert_slice %[[cast]] into %[[arg1]][0, 1] [64, 5] [1, 1] : tensor<64x5x?xf32> into tensor<?x?x?xf32>
// CHECK: return %[[r]]
func @insert_tensor_cast_on_insert_slice_src_prefix(
%arg0 : tensor<?x5x?xf32>, %arg1 : tensor<?x?x?xf32>, %sz0: index, %sz2: index) -> tensor<?x?x?xf32> {
%c64 = arith.constant 64: index
%r = tensor.insert_slice %arg0 into %arg1[0, 1] [%c64, 5] [1, 1]
: tensor<?x5x?xf32> into tensor<?x?x?xf32>
return %r : tensor<?x?x?xf32>
}
// -----
// CHECK-LABEL: func @fold_extract_insert
// CHECK-SAME: %{{.+}}: tensor<?x?x?xf32>, %[[SLICE:.+]]: tensor<4x?x8xf32>
func @fold_extract_insert(%input : tensor<?x?x?xf32>, %slice: tensor<4x?x8xf32>, %i: index, %size: index) -> (tensor<4x?x8xf32>) {