[mlir][Linalg] Disable fusion of reshape with elementwise ops for purely dynamic cases.
Fusing a `tensor.collapse_shape` op with a consumer elementwise `linalg.generic` operation results in the creation of `tensor.expand_shape` ops. In purely dynamic cases this can end up expanding a dynamic dimension into more than one dynamic dimension, which is disallowed by the semantics of the `tensor.expand_shape` operation. (The transformation itself is correct; the issue is a gap in the specification of `tensor.expand_shape`.) So disallow fusions that would result in such a pattern.

Differential Revision: https://reviews.llvm.org/D116703
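For illustration (a hedged sketch, not part of the commit; the value names are hypothetical): with a fully dynamic `tensor<?x?xf32>` input collapsed to `tensor<?xf32>`, fusing the reshape past the elementwise consumer would have to materialize something like

    // Invalid: expands one dynamic extent into two dynamic extents,
    // which `tensor.expand_shape` cannot express.
    %expanded = tensor.expand_shape %collapsed [[0, 1]]
        : tensor<?xf32> into tensor<?x?xf32>

so the pattern now bails out on such cases instead of producing this op.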
parent 0b5340acb7
commit 4317a3dfad
@@ -524,6 +524,7 @@ public:
   LogicalResult compute(LinalgOp linalgOp, OpOperand *fusableOpOperand,
                         ArrayRef<AffineMap> reassociationMaps,
                         ArrayRef<int64_t> expandedShape,
+                        ArrayRef<int64_t> collapsedShape,
                         PatternRewriter &rewriter);
   unsigned getOrigOpNumDims() const { return reassociation.size(); }
   unsigned getExpandedOpNumDims() const { return expandedOpNumDims; }
@@ -533,6 +534,7 @@ public:
   ArrayRef<int64_t> getExpandedShapeOfDim(unsigned i) const {
     return expandedShapeMap[i];
   }
+  ArrayRef<int64_t> getOriginalShape() const { return originalLoopExtent; }

 private:
   /// Reassociation from the dimensions in the original operation to the
@@ -541,6 +543,8 @@ private:
   /// Mapping from extent of loops in the original operation, to the extent of
   /// loops in the expanded operation.
   SmallVector<SmallVector<int64_t>> expandedShapeMap;
+  /// Extent of the loop in the original operation.
+  SmallVector<int64_t> originalLoopExtent;
   unsigned expandedOpNumDims;
 };
 } // namespace
@@ -549,6 +553,7 @@ LogicalResult ExpansionInfo::compute(LinalgOp linalgOp,
                                      OpOperand *fusableOpOperand,
                                      ArrayRef<AffineMap> reassociationMaps,
                                      ArrayRef<int64_t> expandedShape,
+                                     ArrayRef<int64_t> collapsedShape,
                                      PatternRewriter &rewriter) {
   if (reassociationMaps.empty())
     return failure();
@@ -558,6 +563,8 @@ LogicalResult ExpansionInfo::compute(LinalgOp linalgOp,
       linalgOp.getStaticLoopRanges();
   if (!originalLoopRange)
     return rewriter.notifyMatchFailure(linalgOp, "unable to find loop range");
+  originalLoopExtent.assign(originalLoopRange->begin(),
+                            originalLoopRange->end());

   reassociation.clear();
   expandedShapeMap.clear();
@@ -576,7 +583,7 @@ LogicalResult ExpansionInfo::compute(LinalgOp linalgOp,
   // The remaining dimensions remain the same.
   for (unsigned i : llvm::seq<unsigned>(0, fusedIndexMap.getNumDims()))
     if (expandedShapeMap[i].empty())
-      expandedShapeMap[i] = {(*originalLoopRange)[i]};
+      expandedShapeMap[i] = {originalLoopExtent[i]};

   // Compute reassociation map from the original op to the expanded op.
   unsigned sum = 0;
@@ -601,6 +608,30 @@ LogicalResult ExpansionInfo::compute(LinalgOp linalgOp,
 LogicalResult isGenericOpExpandable(GenericOp genericOp,
                                     const ExpansionInfo &expansionInfo,
                                     PatternRewriter &rewriter) {
+  // Current reshape only supports expansion of a dynamic dim when only one
+  // of the expanded dims is dynamic.
+  for (auto originalShape : llvm::enumerate(expansionInfo.getOriginalShape()))
+    if (ShapedType::isDynamic(originalShape.value())) {
+      // All but one of the expanded dims must be static.
+      bool foundDynamicExpandedDim = false;
+      for (auto expandedShape :
+           expansionInfo.getExpandedShapeOfDim(originalShape.index())) {
+        if (ShapedType::isDynamic(expandedShape)) {
+          if (foundDynamicExpandedDim) {
+            return rewriter.notifyMatchFailure(
+                genericOp,
+                "cannot expand dynamic dims into multiple dynamic dims");
+          }
+          foundDynamicExpandedDim = true;
+        }
+      }
+      if (!foundDynamicExpandedDim) {
+        return rewriter.notifyMatchFailure(
+            genericOp, "dynamic dim expansion needs at least one dynamic dim "
+                       "in result shape");
+      }
+    }
+
   if (!genericOp.hasIndexSemantics())
     return success();
   for (unsigned i : llvm::seq<unsigned>(0, expansionInfo.getOrigOpNumDims())) {
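Illustrative shapes for the guard above (these examples are not taken from the commit): for a dynamic original extent `?`,

    ?  ->  [2, ?, 4]   accepted: exactly one dynamic expanded dim
    ?  ->  [?, ?]      rejected by the first notifyMatchFailure
    ?  ->  [2, 4]      rejected by the second: a dynamic extent cannot
                       expand into only static extents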
@@ -731,13 +762,16 @@ fuseWithReshapeByExpansion(GenericOp genericOp, Operation *reshapeOp,
   RankedTensorType expandedType = isExpanding
                                       ? expandingReshapeOp.getResultType()
                                       : collapsingReshapeOp.getSrcType();
+  RankedTensorType collapsedType = isExpanding
+                                       ? expandingReshapeOp.getSrcType()
+                                       : collapsingReshapeOp.getResultType();

   ExpansionInfo expansionInfo;
   if (failed(expansionInfo.compute(
           genericOp, fusableOpOperand,
           isExpanding ? expandingReshapeOp.getReassociationMaps()
                       : collapsingReshapeOp.getReassociationMaps(),
-          expandedType.getShape(), rewriter)))
+          expandedType.getShape(), collapsedType.getShape(), rewriter)))
     return llvm::None;

   if (failed(isGenericOpExpandable(genericOp, expansionInfo, rewriter)))

@@ -507,3 +507,26 @@ func @unit_dim_reshape_expansion_full
 // FOLDUNITDIM-SAME:   ins(%[[ARG0]], %[[RESHAPE]] : tensor<1x?x1x2x1x4xf32>, tensor<1x?x1x2x1x4xf32>)
 // FOLDUNITDIM-SAME:   outs(%{{.+}} : tensor<1x?x1x2x1x4xf32>)
+
+// -----
+
+func @no_fuse_dynamic_dims(%arg0: tensor<?x?xf32>) -> tensor<?xf32> {
+  %c0 = arith.constant 0 : index
+  %0 = tensor.collapse_shape %arg0 [[0, 1]] : tensor<?x?xf32> into tensor<?xf32>
+  %1 = tensor.dim %0, %c0 : tensor<?xf32>
+  %2 = linalg.init_tensor [%1] : tensor<?xf32>
+  %3 = linalg.generic {
+      indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>],
+      iterator_types = ["parallel"]}
+      ins(%0 : tensor<?xf32>) outs(%2 : tensor<?xf32>) {
+    ^bb0(%arg1 : f32, %arg2: f32):
+      %4 = arith.addf %arg1, %arg1 : f32
+      linalg.yield %4 : f32
+  } -> tensor<?xf32>
+  return %3 : tensor<?xf32>
+}
+//      CHECK: func @no_fuse_dynamic_dims
+// CHECK-SAME:     %[[ARG0:.+]]: tensor<?x?xf32>
+//      CHECK:   %[[RESHAPE:.+]] = tensor.collapse_shape %[[ARG0]]
+//      CHECK:   %[[GENERIC:.+]] = linalg.generic
+// CHECK-SAME:       ins(%[[RESHAPE]] : tensor<?xf32>)
+//      CHECK:   return %[[GENERIC]]
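As a usage sketch (the invocation below is an assumption; the authoritative RUN lines sit at the top of the test file and are not part of this diff):

    mlir-opt -split-input-file -linalg-fuse-elementwise-ops reshape_fusion.mlir | FileCheck reshape_fusion.mlir

The CHECK lines assert that the `tensor.collapse_shape` stays in place and feeds the `linalg.generic` unfused, since both extents of %[[ARG0]] are dynamic.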