diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp index edddfb86e553..49cfec663ef7 100644 --- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp +++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp @@ -1417,11 +1417,11 @@ struct InsertSliceOpSourceCastInserter final return failure(); SmallVector newSrcShape(srcType.getShape().begin(), srcType.getShape().end()); - for (int64_t i = 0; i < srcType.getRank(); ++i) { - if (Optional constInt = - getConstantIntValue(insertSliceOp.getMixedSizes()[i])) - newSrcShape[i] = *constInt; - } + // Offsets / sizes / strides can be a subprefix of the rank; take only the + // leading dimensions. + for (auto en : llvm::enumerate(insertSliceOp.getMixedSizes())) + if (Optional constInt = getConstantIntValue(en.value())) + newSrcShape[en.index()] = *constInt; RankedTensorType newSrcType = RankedTensorType::get(newSrcShape, srcType.getElementType()); diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir index fc9abe439b8a..50fda25cce26 100644 --- a/mlir/test/Dialect/Tensor/canonicalize.mlir +++ b/mlir/test/Dialect/Tensor/canonicalize.mlir @@ -536,6 +536,21 @@ func @insert_tensor_cast_on_insert_slice_src( // ----- +// CHECK-LABEL: func @insert_tensor_cast_on_insert_slice_src_prefix( +// CHECK-SAME: %[[arg0:.*]]: tensor, %[[arg1:.*]]: tensor +// CHECK: %[[cast:.*]] = tensor.cast %[[arg0]] : tensor to tensor<64x5x?xf32> +// CHECK: %[[r:.*]] = tensor.insert_slice %[[cast]] into %[[arg1]][0, 1] [64, 5] [1, 1] : tensor<64x5x?xf32> into tensor +// CHECK: return %[[r]] +func @insert_tensor_cast_on_insert_slice_src_prefix( + %arg0 : tensor, %arg1 : tensor, %sz0: index, %sz2: index) -> tensor { + %c64 = arith.constant 64: index + %r = tensor.insert_slice %arg0 into %arg1[0, 1] [%c64, 5] [1, 1] + : tensor into tensor + return %r : tensor +} + +// ----- + // CHECK-LABEL: func @fold_extract_insert // CHECK-SAME: %{{.+}}: tensor, %[[SLICE:.+]]: tensor<4x?x8xf32> func @fold_extract_insert(%input : tensor, %slice: tensor<4x?x8xf32>, %i: index, %size: index) -> (tensor<4x?x8xf32>) {