forked from OSchip/llvm-project
[mlir][tensor] Fix insert_slice + tensor cast overflow
InsertSliceOp may have subprefix semantics where missing trailing dimensions are automatically inferred directly from the operand shape. This revision fixes an overflow that occurs in such cases when the impl is based on the op rank. Differential Revision: https://reviews.llvm.org/D115549
This commit is contained in:
parent
9a3df8fbc2
commit
5601821dae
|
@ -1417,11 +1417,11 @@ struct InsertSliceOpSourceCastInserter final
|
|||
return failure();
|
||||
SmallVector<int64_t> newSrcShape(srcType.getShape().begin(),
|
||||
srcType.getShape().end());
|
||||
for (int64_t i = 0; i < srcType.getRank(); ++i) {
|
||||
if (Optional<int64_t> constInt =
|
||||
getConstantIntValue(insertSliceOp.getMixedSizes()[i]))
|
||||
newSrcShape[i] = *constInt;
|
||||
}
|
||||
// Offsets / sizes / strides can be a subprefix of the rank; take only the
|
||||
// leading dimensions.
|
||||
for (auto en : llvm::enumerate(insertSliceOp.getMixedSizes()))
|
||||
if (Optional<int64_t> constInt = getConstantIntValue(en.value()))
|
||||
newSrcShape[en.index()] = *constInt;
|
||||
|
||||
RankedTensorType newSrcType =
|
||||
RankedTensorType::get(newSrcShape, srcType.getElementType());
|
||||
|
|
|
@ -536,6 +536,21 @@ func @insert_tensor_cast_on_insert_slice_src(
|
|||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @insert_tensor_cast_on_insert_slice_src_prefix(
|
||||
// CHECK-SAME: %[[arg0:.*]]: tensor<?x5x?xf32>, %[[arg1:.*]]: tensor<?x?x?xf32>
|
||||
// CHECK: %[[cast:.*]] = tensor.cast %[[arg0]] : tensor<?x5x?xf32> to tensor<64x5x?xf32>
|
||||
// CHECK: %[[r:.*]] = tensor.insert_slice %[[cast]] into %[[arg1]][0, 1] [64, 5] [1, 1] : tensor<64x5x?xf32> into tensor<?x?x?xf32>
|
||||
// CHECK: return %[[r]]
|
||||
func @insert_tensor_cast_on_insert_slice_src_prefix(
|
||||
%arg0 : tensor<?x5x?xf32>, %arg1 : tensor<?x?x?xf32>, %sz0: index, %sz2: index) -> tensor<?x?x?xf32> {
|
||||
%c64 = arith.constant 64: index
|
||||
%r = tensor.insert_slice %arg0 into %arg1[0, 1] [%c64, 5] [1, 1]
|
||||
: tensor<?x5x?xf32> into tensor<?x?x?xf32>
|
||||
return %r : tensor<?x?x?xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @fold_extract_insert
|
||||
// CHECK-SAME: %{{.+}}: tensor<?x?x?xf32>, %[[SLICE:.+]]: tensor<4x?x8xf32>
|
||||
func @fold_extract_insert(%input : tensor<?x?x?xf32>, %slice: tensor<4x?x8xf32>, %i: index, %size: index) -> (tensor<4x?x8xf32>) {
|
||||
|
|
Loading…
Reference in New Issue