forked from OSchip/llvm-project
Fix CollapsedLayoutMap for dim size 1 case
This change fixes `CollapsedLayoutMap` for cases where the collapsed dims are of size 1. Cases where the innermost dims are of size 1 and non-contiguous can be represented in strided form and can therefore be allowed. In such cases, the new stride should be that of the next entry in the reassociation whose dimension is not of size 1. If that entry is dynamically sized, it is not possible to decide at compile time which stride to use, so the stride is set to dynamic. Differential Revision:
This commit is contained in:
@ -1824,12 +1824,27 @@ computeCollapsedLayoutMap(MemRefType srcType,
if (failed(getStridesAndOffset(srcType, srcStrides, srcOffset)))
return failure();
// The result strides are exactly the strides of the last entry of each
// reassociation.
// The result stride of a reassociation group is the stride of the last entry
// of the reassociation. (TODO: Should be the minimum stride in the
// reassociation because strides are not necessarily sorted. E.g., when using
// memref.transpose.) Dimensions of size 1 should be skipped, because their
// strides are meaningless and could have any arbitrary value.
SmallVector<int64_t> resultStrides;
for (ReassociationIndices reassoc : reassociation)
for (const ReassociationIndices &reassoc : reassociation) {
ArrayRef<int64_t> ref = llvm::makeArrayRef(reassoc);
while (srcShape[ref.back()] == 1 && ref.size() > 1)
ref = ref.drop_back();
if (!ShapedType::isDynamic(srcShape[ref.back()]) || ref.size() == 1) {
} else {
// Dynamically-sized dims may turn out to be dims of size 1 at runtime, so
// the corresponding stride may have to be skipped. (See above comment.)
// Therefore, the result stride cannot be statically determined and must
// be dynamic.
// Validate that each reassociation group is contiguous.
unsigned resultStrideIndex = resultStrides.size() - 1;
@ -331,14 +331,14 @@ func.func @compose_collapse_of_collapse(%arg0 : memref<?x?x?x?x?xf32>)
func.func @do_not_compose_collapse_of_expand_non_identity_layout(
%arg0: memref<?x?xf32, offset : 0, strides : [?, 1]>)
-> memref<?xf32> {
-> memref<?xf32, offset : 0, strides : [?]> {
%1 = memref.expand_shape %arg0 [[0, 1], [2]] :
memref<?x?xf32, offset : 0, strides : [?, 1]> into
memref<?x4x?xf32, offset : 0, strides : [?, ?, 1]>
%2 = memref.collapse_shape %1 [[0, 1, 2]] :
memref<?x4x?xf32, offset : 0, strides : [?, ?, 1]> into
return %2 : memref<?xf32>
memref<?xf32, offset : 0, strides : [?]>
return %2 : memref<?xf32, offset : 0, strides : [?]>
// CHECK-LABEL: func @do_not_compose_collapse_of_expand_non_identity_layout
// CHECK: expand
@ -1,11 +1,14 @@
// RUN: mlir-opt %s -tensor-bufferize -cse | FileCheck %s
// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>
// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1)[s0] -> (d0 * 20 + s0 + d1)>
// CHECK-DAG: #[[$MAP2:.*]] = affine_map<(d0, d1, d2, d3)[s0] -> (d0 * 140 + d1 * 20 + d2 * 5 + d3 + s0)>
// CHECK-DAG: #[[$MAP3:.*]] = affine_map<(d0) -> (d0 + 1)>
// CHECK-DAG: #[[$MAP4:.*]] = affine_map<() -> (1)>
// CHECK-DAG: #[[$MAP5:.*]] = affine_map<(d0, d1) -> (d0 * 2 + d1)>
// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>
// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1)[s0] -> (d0 * 20 + s0 + d1)>
// CHECK-DAG: #[[$MAP2:.*]] = affine_map<(d0, d1, d2, d3)[s0] -> (d0 * 140 + d1 * 20 + d2 * 5 + d3 + s0)>
// CHECK-DAG: #[[$MAP3:.*]] = affine_map<(d0) -> (d0 + 1)>
// CHECK-DAG: #[[$MAP4:.*]] = affine_map<() -> (1)>
// CHECK-DAG: #[[$MAP5:.*]] = affine_map<(d0, d1) -> (d0 * 2 + d1)>
// CHECK-DAG: #[[$MAP6:.*]] = affine_map<(d0) -> (d0 * 2)>
// CHECK-DAG: #[[$MAP7:.*]] = affine_map<(d0, d1, d2)[s0] -> (d0 * 8 + s0 + d1 * 4 + d2)>
// CHECK-DAG: #[[$MAP8:.*]] = affine_map<(d0)[s0] -> (d0 * 4 + s0)>
// CHECK-LABEL: func @dim(
// CHECK-SAME: %[[TENSOR:.*]]: tensor<f32>,
@ -330,17 +333,6 @@ func.func @tensor.expand_shape_of_slice(
return %1 : tensor<?x7x2x5xf32>
// CHECK-LABEL: func @tensor.expand_shape_of_slice2(
// CHECK-SAME: %[[t1:.*]]: tensor<1x2xf32>
func.func @tensor.expand_shape_of_slice2(%t1: tensor<1x2xf32>) -> tensor<1xf32> {
// CHECK: memref.subview {{.*}} : memref<1x2xf32> to memref<1x1xf32, #[[$MAP5]]>
%0 = tensor.extract_slice %t1[0, 0][1, 1][1, 1] : tensor<1x2xf32> to tensor<1x1xf32>
// CHECK: memref.collapse_shape %{{.*}} [
// CHECK-SAME: [0, 1]] : memref<1x1xf32, #[[$MAP5]]> into memref<1xf32>
%1 = tensor.collapse_shape %0 [[0, 1]] : tensor<1x1xf32> into tensor<1xf32>
return %1 : tensor<1xf32>
// CHECK-LABEL: func @tensor.collapse_shape(
// CHECK-SAME: %[[t1:.*]]: tensor<2x?x?xf32>
func.func @tensor.collapse_shape(%t1: tensor<2x?x?xf32>) -> tensor<?x?xf32> {
@ -393,3 +385,26 @@ func.func @tensor.collapse_shape_of_slice2(
%1 = tensor.collapse_shape %0 [[0], [1, 2, 3]] : tensor<87x78x68x12xi64> into tensor<87x63648xi64>
return %1 : tensor<87x63648xi64>
// CHECK-LABEL: func @tensor.collapse_shape_of_slice3(
// CHECK-SAME: %[[t1:.*]]: tensor<1x2xf32>
func.func @tensor.collapse_shape_of_slice3(%t1: tensor<1x2xf32>) -> tensor<1xf32> {
// CHECK: memref.subview {{.*}} : memref<1x2xf32> to memref<1x1xf32, #[[$MAP5]]>
%0 = tensor.extract_slice %t1[0, 0][1, 1][1, 1] : tensor<1x2xf32> to tensor<1x1xf32>
// CHECK: memref.collapse_shape %{{.*}} [
// CHECK-SAME: [0, 1]] : memref<1x1xf32, #[[$MAP5]]> into memref<1xf32, #[[$MAP6]]>
%1 = tensor.collapse_shape %0 [[0, 1]] : tensor<1x1xf32> into tensor<1xf32>
return %1 : tensor<1xf32>
// CHECK-LABEL: func @tensor.collapse_shape_of_slice4(
// CHECK-SAME: %[[t1:.*]]: tensor<?x2x4xf32>,
// CHECK-SAME: %[[OFFSET:.*]]: index) -> tensor<8xf32> {
func.func @tensor.collapse_shape_of_slice4(%arg0: tensor<?x2x4xf32>, %offset: index, %size: index) -> tensor<8xf32> {
// CHECK: memref.subview %{{.*}} : memref<?x2x4xf32> to memref<4x2x1xf32, #[[$MAP7]]>
%0 = tensor.extract_slice %arg0[0, 0, %offset] [4, 2, 1] [1, 1, 1] : tensor<?x2x4xf32> to tensor<4x2x1xf32>
// CHECK: memref.collapse_shape %{{.*}} [
// CHECK-SAME: [0, 1, 2]] : memref<4x2x1xf32, #[[$MAP7]]> into memref<8xf32, #[[$MAP8]]>
%ret = tensor.collapse_shape %0 [[0, 1, 2]] : tensor<4x2x1xf32> into tensor<8xf32>
return %ret: tensor<8xf32>
Reference in New Issue