[mlir][MemRef] Fix a crash when expanding a scalar shape

In this case the reassociation is empty, yielding no strides for the
result type.

Differential Revision: https://reviews.llvm.org/D127232
This commit is contained in:
Benjamin Kramer 2022-06-07 19:30:10 +02:00
parent c6d6535a26
commit 6eb0f8e285
2 changed files with 20 additions and 3 deletions

View File

@ -1734,9 +1734,10 @@ computeExpandedLayoutMap(MemRefType srcType, ArrayRef<int64_t> resultShape,
.asStride();
}
}
return makeStridedLinearLayoutMap(
llvm::to_vector<8>(llvm::reverse(reverseResultStrides)), srcOffset,
srcType.getContext());
auto resultStrides = llvm::to_vector<8>(llvm::reverse(reverseResultStrides));
resultStrides.resize(resultShape.size(), 1);
return makeStridedLinearLayoutMap(resultStrides, srcOffset,
srcType.getContext());
}
static FailureOr<MemRefType>

View File

@ -9,6 +9,8 @@
// CHECK-DAG: #[[$MAP6:.*]] = affine_map<(d0) -> (d0 * 2)>
// CHECK-DAG: #[[$MAP7:.*]] = affine_map<(d0, d1, d2)[s0] -> (d0 * 8 + s0 + d1 * 4 + d2)>
// CHECK-DAG: #[[$MAP8:.*]] = affine_map<(d0)[s0] -> (d0 * 4 + s0)>
// CHECK-DAG: #[[$MAP9:.*]] = affine_map<()[s0] -> (s0)>
// CHECK-DAG: #[[$MAP10:.*]] = affine_map<(d0)[s0] -> (d0 + s0)>
// CHECK-LABEL: func @dim(
// CHECK-SAME: %[[TENSOR:.*]]: tensor<f32>,
@ -333,6 +335,20 @@ func.func @tensor.expand_shape_of_slice(
return %1 : tensor<?x7x2x5xf32>
}
// CHECK-LABEL: func @tensor.expand_shape_of_scalar_slice(
// CHECK-SAME: %[[t1:.*]]: tensor<?xf32>
func.func @tensor.expand_shape_of_scalar_slice(
%t1: tensor<?xf32>, %o1: index, %s1: index) -> tensor<1xf32> {
// CHECK: %[[m1:.*]] = bufferization.to_memref %[[t1]] : memref<?xf32>
// CHECK: %[[subview:.*]] = memref.subview %[[m1]][%{{.*}}] [1] [1] : memref<?xf32> to memref<f32, #[[$MAP9]]>
%0 = tensor.extract_slice %t1[%o1][1][1] : tensor<?xf32> to tensor<f32>
// CHECK: %[[expanded:.*]] = memref.expand_shape %[[subview]] [] : memref<f32, #[[$MAP9]]> into memref<1xf32, #[[$MAP10]]>
%1 = tensor.expand_shape %0 [] : tensor<f32> into tensor<1xf32>
// CHECK: %[[r:.*]] = bufferization.to_tensor %[[expanded]]
// CHECK: return %[[r]]
return %1 : tensor<1xf32>
}
// CHECK-LABEL: func @tensor.collapse_shape(
// CHECK-SAME: %[[t1:.*]]: tensor<2x?x?xf32>
func.func @tensor.collapse_shape(%t1: tensor<2x?x?xf32>) -> tensor<?x?xf32> {