[mlir][Linalg] Disable fusion of reshape with elementwise ops for purely dynamic cases.

Fusing a `tensor.collapse_shape` op with a consumer elementwise
`linalg.generic` operation results in the creation of `tensor.expand_shape`
ops. In purely dynamic cases this can end up expanding a dynamic dimension
into more than one dynamic dimension, which is disallowed by the semantics
of the `tensor.expand_shape` operation. (While the transformation itself is
correct, the issue is a gap in the specification of `tensor.expand_shape`.)
So disallow fusions that result in such a pattern.
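
For illustration, a minimal sketch (names illustrative, not part of this
patch) of the reshape the fusion would have to introduce for a fully dynamic
operand; `tensor.expand_shape` provides no way to specify how the single
dynamic extent should be split across two dynamic result dimensions:

  // Hypothetical IR: fusing the collapse into its elementwise consumer
  // requires re-expanding the collapsed value, but both result dims of the
  // expansion are dynamic.
  %collapsed = tensor.collapse_shape %arg0 [[0, 1]]
      : tensor<?x?xf32> into tensor<?xf32>
  %expanded = tensor.expand_shape %collapsed [[0, 1]]
      : tensor<?xf32> into tensor<?x?xf32>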

Differential Revision: https://reviews.llvm.org/D116703
MaheshRavishankar 2022-01-06 10:31:53 -08:00
parent 0b5340acb7
commit 4317a3dfad
2 changed files with 59 additions and 2 deletions


@@ -524,6 +524,7 @@ public:
LogicalResult compute(LinalgOp linalgOp, OpOperand *fusableOpOperand,
ArrayRef<AffineMap> reassociationMaps,
ArrayRef<int64_t> expandedShape,
ArrayRef<int64_t> collapsedShape,
PatternRewriter &rewriter);
unsigned getOrigOpNumDims() const { return reassociation.size(); }
unsigned getExpandedOpNumDims() const { return expandedOpNumDims; }
@@ -533,6 +534,7 @@ public:
ArrayRef<int64_t> getExpandedShapeOfDim(unsigned i) const {
return expandedShapeMap[i];
}
ArrayRef<int64_t> getOriginalShape() const { return originalLoopExtent; }
private:
/// Reassociation from the dimensions in the original operation to the
@@ -541,6 +543,8 @@ private:
/// Mapping from extent of loops in the original operation, to the extent of
/// loops in the expanded operation.
SmallVector<SmallVector<int64_t>> expandedShapeMap;
/// Extent of the loop in the original operation.
SmallVector<int64_t> originalLoopExtent;
unsigned expandedOpNumDims;
};
} // namespace
@@ -549,6 +553,7 @@ LogicalResult ExpansionInfo::compute(LinalgOp linalgOp,
OpOperand *fusableOpOperand,
ArrayRef<AffineMap> reassociationMaps,
ArrayRef<int64_t> expandedShape,
ArrayRef<int64_t> collapsedShape,
PatternRewriter &rewriter) {
if (reassociationMaps.empty())
return failure();
@@ -558,6 +563,8 @@ LogicalResult ExpansionInfo::compute(LinalgOp linalgOp,
linalgOp.getStaticLoopRanges();
if (!originalLoopRange)
return rewriter.notifyMatchFailure(linalgOp, "unable to find loop range");
originalLoopExtent.assign(originalLoopRange->begin(),
originalLoopRange->end());
reassociation.clear();
expandedShapeMap.clear();
@@ -576,7 +583,7 @@ LogicalResult ExpansionInfo::compute(LinalgOp linalgOp,
// The remaining dimensions remain the same.
for (unsigned i : llvm::seq<unsigned>(0, fusedIndexMap.getNumDims()))
if (expandedShapeMap[i].empty())
expandedShapeMap[i] = {(*originalLoopRange)[i]};
expandedShapeMap[i] = {originalLoopExtent[i]};
// Compute reassociation map from the original op to the expanded op.
unsigned sum = 0;
@@ -601,6 +608,30 @@ LogicalResult ExpansionInfo::compute(LinalgOp linalgOp,
LogicalResult isGenericOpExpandable(GenericOp genericOp,
const ExpansionInfo &expansionInfo,
PatternRewriter &rewriter) {
// Current reshape only supports expansion of a dynamic dim when only one of
// the expanded dims is dynamic.
for (auto originalShape : llvm::enumerate(expansionInfo.getOriginalShape()))
if (ShapedType::isDynamic(originalShape.value())) {
// All but one of the expanded dims must be static.
bool foundDynamicExpandedDim = false;
for (auto expandedShape :
expansionInfo.getExpandedShapeOfDim(originalShape.index())) {
if (ShapedType::isDynamic(expandedShape)) {
if (foundDynamicExpandedDim) {
return rewriter.notifyMatchFailure(
genericOp,
"cannot expanded dynamic dims into multiple dynamic dims");
}
foundDynamicExpandedDim = true;
}
}
if (!foundDynamicExpandedDim) {
return rewriter.notifyMatchFailure(
genericOp, "dynamic dim expansion needs at least one dynamic dim "
"in result shape");
}
}
if (!genericOp.hasIndexSemantics())
return success();
for (unsigned i : llvm::seq<unsigned>(0, expansionInfo.getOrigOpNumDims())) {
@@ -731,13 +762,16 @@ fuseWithReshapeByExpansion(GenericOp genericOp, Operation *reshapeOp,
RankedTensorType expandedType = isExpanding
? expandingReshapeOp.getResultType()
: collapsingReshapeOp.getSrcType();
RankedTensorType collapsedType = isExpanding
? expandingReshapeOp.getSrcType()
: collapsingReshapeOp.getResultType();
ExpansionInfo expansionInfo;
if (failed(expansionInfo.compute(
genericOp, fusableOpOperand,
isExpanding ? expandingReshapeOp.getReassociationMaps()
: collapsingReshapeOp.getReassociationMaps(),
expandedType.getShape(), rewriter)))
expandedType.getShape(), collapsedType.getShape(), rewriter)))
return llvm::None;
if (failed(isGenericOpExpandable(genericOp, expansionInfo, rewriter)))
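
To make the new check concrete, a sketch of the values `ExpansionInfo` would
record for the fully dynamic test added below (values assumed, not printed by
the patch):

  // %arg0 : tensor<?x?xf32>, reassociation [[0, 1]], rank-1 consumer generic:
  //   getOriginalShape()       = [?]     // extent of the single loop d0
  //   getExpandedShapeOfDim(0) = [?, ?]  // from the tensor<?x?xf32> source
  // Both expanded entries are dynamic, so isGenericOpExpandable rejects the
  // fusion via notifyMatchFailure.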


@@ -507,3 +507,26 @@ func @unit_dim_reshape_expansion_full
// FOLDUNITDIM-SAME: ins(%[[ARG0]], %[[RESHAPE]] : tensor<1x?x1x2x1x4xf32>, tensor<1x?x1x2x1x4xf32>)
// FOLDUNITDIM-SAME: outs(%{{.+}} : tensor<1x?x1x2x1x4xf32>)
// -----
func @no_fuse_dynamic_dims(%arg0: tensor<?x?xf32>) -> tensor<?xf32> {
%c0 = arith.constant 0 : index
%0 = tensor.collapse_shape %arg0 [[0, 1]] : tensor<?x?xf32> into tensor<?xf32>
%1 = tensor.dim %0, %c0 : tensor<?xf32>
%2 = linalg.init_tensor [%1] : tensor<?xf32>
%3 = linalg.generic {
indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>],
iterator_types = ["parallel"]}
ins(%0 : tensor<?xf32>) outs(%2 : tensor<?xf32>) {
^bb0(%arg1 : f32, %arg2: f32):
%4 = arith.addf %arg1, %arg1 : f32
linalg.yield %4 : f32
} -> tensor<?xf32>
return %3 : tensor<?xf32>
}
// CHECK: func @no_fuse_dynamic_dims
// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?xf32>
// CHECK: %[[RESHAPE:.+]] = tensor.collapse_shape %[[ARG0]]
// CHECK: %[[GENERIC:.+]] = linalg.generic
// CHECK-SAME: ins(%[[RESHAPE]] : tensor<?xf32>)
// CHECK: return %[[GENERIC]]
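
For contrast, a sketch (not a test from this patch; names illustrative) of a
shape the new check still accepts: the dynamic original dimension expands into
a group with exactly one dynamic dimension, so the re-expansion the fusion
introduces remains expressible:

  // Only one expanded dim of d0 is dynamic, so isGenericOpExpandable succeeds
  // and the collapse_shape can still be fused with its consumer.
  %c = tensor.collapse_shape %arg0 [[0, 1]] : tensor<?x4xf32> into tensor<?xf32>
  %e = tensor.expand_shape %c [[0, 1]] : tensor<?xf32> into tensor<?x4xf32>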