Relax FuseTensorReshapeOpAsProducer identity mapping constraint
Differential Revision: https://reviews.llvm.org/D88869
commit 7060920bd1 (parent 5e4409f308)

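In short: isTensorReshapeOpFusible now accepts any permutation of the iteration dimensions as the reshape's use indexing map instead of requiring the identity, and both fusion patterns compose with the inverse of that permutation before linearizing the collapsed dimensions. As a minimal sketch of the two AffineMap utilities involved (a hypothetical standalone driver, not part of this patch, assuming an MLIR checkout around this revision and linking against the MLIR IR libraries):

    #include "mlir/IR/AffineMap.h"
    #include "mlir/IR/MLIRContext.h"
    #include "llvm/Support/raw_ostream.h"

    int main() {
      mlir::MLIRContext context;
      // The (d0, d1, d2) -> (d0, d2, d1) permutation used by the 021 test below.
      mlir::AffineMap map =
          mlir::AffineMap::getPermutationMap({0, 2, 1}, &context);
      llvm::errs() << "isIdentity: " << map.isIdentity() << "\n";       // prints 0
      llvm::errs() << "isPermutation: " << map.isPermutation() << "\n"; // prints 1
      // inversePermutation undoes the permutation; this particular swap is its
      // own inverse.
      mlir::AffineMap inv = mlir::inversePermutation(map);
      inv.print(llvm::errs()); // (d0, d1, d2) -> (d0, d2, d1)
      llvm::errs() << "\n";
      return 0;
    }
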
@@ -326,7 +326,7 @@ static bool isTensorReshapeOpFusible(TensorReshapeOp reshapeOp,
   if ((asProducer && returnType.getRank() < operandType.getRank()) ||
       (!asProducer && operandType.getRank() < returnType.getRank()))
     return false;
-  return useIndexMap.isIdentity();
+  return useIndexMap.isPermutation();
 }
 
 /// Based on the type of `op` create a linalg op of the same type, i.e. if `op`

@@ -381,10 +381,13 @@ struct FuseTensorReshapeOpAsProducer {
           return attr.cast<AffineMapAttr>().getValue();
         }));
 
+    // Accepted consumer maps are either identity or permutation.
+    auto invMap = inversePermutation(fusedIndexMaps[consumerIdx]);
+
     // Compute the indexing map to use for the operand of the producer.
-    AffineMap modifiedMap = linearizeCollapsedDims(
-        fusedIndexMaps[consumerIdx], producer.getResultType().getShape(),
-        producer.getReassociationMaps());
+    AffineMap modifiedMap =
+        linearizeCollapsedDims(invMap, producer.getResultType().getShape(),
+                               producer.getReassociationMaps());
     for (AffineExpr expr : modifiedMap.getResults()) {
       if (!expr.isPureAffine())
         return nullptr;

@@ -439,10 +442,13 @@ struct FuseTensorReshapeOpAsConsumer {
         producer.indexing_maps(), [](Attribute attr) -> AffineMap {
           return attr.cast<AffineMapAttr>().getValue();
         }));
+
+    auto invMap = inversePermutation(producer.getOutputIndexingMap(0));
+
     // Compute the indexing map to use for the operand of the producer.
-    AffineMap modifiedMap = linearizeCollapsedDims(
-        producer.getOutputIndexingMap(0), consumer.getSrcType().getShape(),
-        consumer.getReassociationMaps());
+    AffineMap modifiedMap =
+        linearizeCollapsedDims(invMap, consumer.getSrcType().getShape(),
+                               consumer.getReassociationMaps());
     for (AffineExpr expr : modifiedMap.getResults()) {
       if (!expr.isPureAffine())
         return nullptr;

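The patterns above linearize the inverse of the (permuted) use map rather than the map itself; the tests added below then check the resulting indexing maps, e.g. affine_map<(d0, d1, d2) -> (d0, d1 + d2 * 7)> for the 021-permutation producer case. The sketch below (again a hypothetical driver, not code from this patch) checks the property this relies on: composing a permutation map with its inversePermutation gives the identity map.

    #include "mlir/IR/AffineMap.h"
    #include "mlir/IR/MLIRContext.h"
    #include <cassert>

    int main() {
      mlir::MLIRContext context;
      // #map2 from the first new test: (d0, d1, d2) -> (d0, d2, d1).
      mlir::AffineMap perm =
          mlir::AffineMap::getPermutationMap({0, 2, 1}, &context);
      mlir::AffineMap inv = mlir::inversePermutation(perm);
      // inv(perm(d)) == d and perm(inv(d)) == d, so both compositions are the
      // identity map.
      assert(inv.compose(perm).isIdentity());
      assert(perm.compose(inv).isIdentity());
      return 0;
    }
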
@@ -558,3 +558,100 @@ func @indexed_generic_op_reshape_consumer_fusion(%arg0 : tensor<?x?x4x5xi32>)
 // CHECK: linalg.indexed_generic
 // CHECK-SAME: indexing_maps = [#[[$MAP0]], #[[$MAP1]]]
 // CHECK-NOT: linalg.tensor_reshape
+
+// -----
+
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2) -> (d0, d1 + d2 * 7)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+
+#map0 = affine_map<(d0, d1, d2) -> (d0)>
+#map1 = affine_map<(d0, d1, d2) -> (d1, d2)>
+#map2 = affine_map<(d0, d1, d2) -> (d0, d2, d1)>
+#map3 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+func @generic_op_021_permultation_reshape_producer_fusion(%arg0 : tensor<3x35xf32>) -> tensor<3x7x5xf32> {
+  %0 = linalg.tensor_reshape %arg0 [#map0, #map1] : tensor<3x35xf32> into tensor<3x5x7xf32>
+  %1 = linalg.generic {indexing_maps = [#map2, #map3], iterator_types = ["parallel", "parallel", "parallel"]} ins(%0 : tensor<3x5x7xf32>) {
+  ^bb0(%arg2: f32): // no predecessors
+    linalg.yield %arg2 : f32
+  } -> tensor<3x7x5xf32>
+  return %1 : tensor<3x7x5xf32>
+}
+
+// CHECK-LABEL: func @generic_op_021_permultation_reshape_producer_fusion
+// CHECK-NOT: linalg.tensor_reshape
+// CHECK: linalg.generic
+// CHECK-SAME: indexing_maps = [#[[$MAP0]], #[[$MAP1]]]
+// CHECK-NOT: linalg.tensor_reshape
+
+// -----
+
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2) -> (d2, d0 * 7 + d1)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+
+#map0 = affine_map<(d0, d1, d2) -> (d0)>
+#map1 = affine_map<(d0, d1, d2) -> (d1, d2)>
+#map2 = affine_map<(d0, d1, d2) -> (d1, d2, d0)>
+#map3 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+func @generic_op_120_permultation_reshape_producer_fusion(%arg0 : tensor<3x35xf32>) -> tensor<5x7x3xf32> {
+  %0 = linalg.tensor_reshape %arg0 [#map0, #map1] : tensor<3x35xf32> into tensor<3x5x7xf32>
+  %1 = linalg.generic {indexing_maps = [#map2, #map3], iterator_types = ["parallel", "parallel", "parallel"]} ins(%0 : tensor<3x5x7xf32>) {
+  ^bb0(%arg2: f32): // no predecessors
+    linalg.yield %arg2 : f32
+  } -> tensor<5x7x3xf32>
+  return %1 : tensor<5x7x3xf32>
+}
+
+// CHECK-LABEL: func @generic_op_120_permultation_reshape_producer_fusion
+// CHECK-NOT: linalg.tensor_reshape
+// CHECK: linalg.generic
+// CHECK-SAME: indexing_maps = [#[[$MAP0]], #[[$MAP1]]]
+// CHECK-NOT: linalg.tensor_reshape
+
+// -----
+
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2) -> (d1, d0 * 7 + d2)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+
+#map0 = affine_map<(d0, d1, d2) -> (d0)>
+#map1 = affine_map<(d0, d1, d2) -> (d1, d2)>
+#map2 = affine_map<(d0, d1, d2) -> (d1, d0, d2)>
+#map3 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+func @generic_op_102_permultation_reshape_producer_fusion(%arg0 : tensor<3x35xf32>) -> tensor<5x3x7xf32> {
+  %0 = linalg.tensor_reshape %arg0 [#map0, #map1] : tensor<3x35xf32> into tensor<3x5x7xf32>
+  %1 = linalg.generic {indexing_maps = [#map2, #map3], iterator_types = ["parallel", "parallel", "parallel"]} ins(%0 : tensor<3x5x7xf32>) {
+  ^bb0(%arg2: f32): // no predecessors
+    linalg.yield %arg2 : f32
+  } -> tensor<5x3x7xf32>
+  return %1 : tensor<5x3x7xf32>
+}
+
+// CHECK-LABEL: func @generic_op_102_permultation_reshape_producer_fusion
+// CHECK-NOT: linalg.tensor_reshape
+// CHECK: linalg.generic
+// CHECK-SAME: indexing_maps = [#[[$MAP0]], #[[$MAP1]]]
+// CHECK-NOT: linalg.tensor_reshape
+
+// -----
+
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2) -> (d1, d0 * 7 + d2)>
+
+
+#map0 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d1, d0, d2)>
+#map2 = affine_map<(d0, d1, d2) -> (d0)>
+#map3 = affine_map<(d0, d1, d2) -> (d1, d2)>
+func @generic_op_102_permultation_reshape_consumer_fusion(%arg0 : tensor<3x5x7xf32>) -> tensor<5x21xf32> {
+  %0 = linalg.generic {indexing_maps = [#map0, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%arg0 : tensor<3x5x7xf32>) {
+  ^bb0(%arg2: f32): // no predecessors
+    linalg.yield %arg2 : f32
+  } -> tensor<5x3x7xf32>
+  %1 = linalg.tensor_reshape %0 [#map2, #map3] : tensor<5x3x7xf32> into tensor<5x21xf32>
+  return %1 : tensor<5x21xf32>
+}
+
+// CHECK-LABEL: func @generic_op_102_permultation_reshape_consumer_fusion
+// CHECK-NOT: linalg.tensor_reshape
+// CHECK: linalg.generic
+// CHECK-SAME: indexing_maps = [#[[$MAP0]], #[[$MAP1]]]
+// CHECK-NOT: linalg.tensor_reshape