[mlir][Linalg] Deprecate legacy reshape + generic op folding patterns.

These patterns have been superseded by the fusion-by-collapsing patterns.
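For out-of-tree users of the removed entry points, the collapsing-based API is the replacement. A minimal migration sketch, assuming populateFoldReshapeOpsByCollapsingPatterns takes a RewritePatternSet plus a ControlFusionFn like the other populate entry points in this header (the helper name below is illustrative only):

// Illustrative sketch only: the exact parameters of
// populateFoldReshapeOpsByCollapsingPatterns are assumed to mirror the other
// populate* entry points (RewritePatternSet + ControlFusionFn).
void addReshapeCollapsingFusion(Operation *op) {
  RewritePatternSet patterns(op->getContext());
  // Decide per producer-result / consumer-operand pair whether to fuse; this
  // subsumes the removed allow-folding-unit-dim-reshapes style switches.
  linalg::ControlFusionFn controlFn =
      [](const OpResult & /*producer*/, OpOperand & /*consumer*/) {
        return true;
      };
  linalg::populateFoldReshapeOpsByCollapsingPatterns(patterns, controlFn);
  (void)applyPatternsAndFoldGreedily(op->getRegions(), std::move(patterns));
}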

Differential Revision: https://reviews.llvm.org/D124145
Mahesh Ravishankar 2022-04-21 04:54:16 +00:00
parent e8572aca0c
commit 0c090dcc8a
8 changed files with 75 additions and 1027 deletions

View File

@ -90,31 +90,11 @@ def LinalgFoldUnitExtentDims : Pass<"linalg-fold-unit-extent-dims", ""> {
def LinalgElementwiseOpFusion : Pass<"linalg-fuse-elementwise-ops"> {
let summary = "Fuse elementwise operations on tensors";
let constructor = "mlir::createLinalgElementwiseOpFusionPass()";
let options = [
Option<"allowFoldingUnitDimReshapes", "allow-folding-unit-dim-reshapes",
"bool", /*default=*/"false",
"Allow fusing linalg.tensor_reshape ops that performs unit "
"dimension collapsing">
];
let dependentDialects = [
"AffineDialect", "linalg::LinalgDialect", "memref::MemRefDialect"
];
}
def LinalgFoldReshapeOpsByLinearization :
Pass<"linalg-fold-reshape-ops-by-linearization"> {
let summary = "Fold TensorReshapeOps with generic/indexed generic ops by "
"linearization";
let constructor = "mlir::createFoldReshapeOpsByLinearizationPass()";
let options = [
Option<"allowFoldingUnitDimReshapes", "allow-folding-unit-dim-reshapes",
"bool", /*default=*/"false",
"Allow fusing linalg.tensor_reshape ops that performs unit "
"dimension collapsing">
];
let dependentDialects = ["AffineDialect", "memref::MemRefDialect"];
}
def LinalgNamedOpConversion: Pass<"linalg-named-op-conversion"> {
let summary = "Convert from one named linalg op to another.";
let constructor = "mlir::createLinalgNamedOpConversionPass()";

View File

@ -37,10 +37,6 @@ struct LinalgElementwiseFusionOptions;
struct LinalgFusionOptions;
struct LinalgTilingOptions;
/// Default function to control reshape folding. Skips folding unit dimension
/// reshapes.
bool skipUnitDimReshape(const OpResult &producer, OpOperand &consumer);
//===----------------------------------------------------------------------===//
// Transformations exposed as function calls.
//===----------------------------------------------------------------------===//
@ -91,24 +87,6 @@ void populateFoldReshapeOpsByCollapsingPatterns(
void populateConstantFoldLinalgOperations(RewritePatternSet &patterns,
const ControlFusionFn &controlFn);
/// Patterns to fold a collapsing (expanding) tensor_reshape operation with its
/// producer (consumer) generic operation by linearizing the indexing map used
/// to access the source (target) of the reshape operation in the generic
/// operation.
/// TODO(ravishankarm): These patterns are to be deprecated in favor of using
/// the `populateFoldReshapeByCollapsingPatterns`.
void populateFoldReshapeOpsByLinearizationPatterns(RewritePatternSet &patterns);
/// Patterns to fold a collapsing (expanding) tensor_reshape operation with its
/// producer (consumer) generic operation by linearizing the indexing map used
/// to access the source (target) of the reshape operation in the generic
/// operation. The patterns are applied only when the tensor reshape involved is
/// collapsing (introducing) unit-extent dimensions.
/// TODO(ravishankarm): These patterns are to be deprecated in favor of using
/// the `populateFoldReshapeByCollapsingPatterns`.
void populateFoldUnitDimsReshapeOpsByLinearizationPatterns(
RewritePatternSet &patterns);
/// Pattern to fuse a `tensor.pad` operation with the producer of its source,
/// if the producer is a `linalg` operation with all parallel iterator types.
void populateFuseTensorPadWithProducerLinalgOpPatterns(
@ -128,12 +106,6 @@ void populateInlineConstantOperandsPatterns(RewritePatternSet &patterns);
/// Patterns that are used to bubble up extract slice op above linalg op.
void populateBubbleUpExtractSliceOpPatterns(RewritePatternSet &patterns);
/// Patterns to push reshape op towards the end of the graph in order to expose
/// more fusion opportunities.
/// TODO(ravishankarm): These patterns are to be deprecated in favor of using
/// the `populateFoldReshapeByCollapsingPatterns`.
void populatePushReshapeOpsPatterns(RewritePatternSet &patterns);
/// Perform standalone tiling of a single LinalgOp by `tileSizes`.
/// and permute the loop nest according to `interchangeVector`
/// The permutation is expressed as a list of integers that specify

View File

@ -392,263 +392,6 @@ private:
};
} // namespace
//===---------------------------------------------------------------------===//
// Methods and patterns that fuse reshape ops with elementwise operations by
// linearization of indexing maps.
//===---------------------------------------------------------------------===//
// TODO(ravishankarm): The indexing maps
// these produce in the general case are detrimental to transformations.
// These patterns are on deprecation path in favor of using fusion by
// collapsing, which covers the only legitimate use case of this pattern of
// folding unit-extent dims.
/// Linearize the expressions in `sourceMap` based on the `reassociationMaps`
/// provided, given the shape of the source tensor that corresponds to the
/// `sourceMap`. Note that this implicitly assumes that the tensor's dimensions
/// are "row-major" ordered logically.
///
/// For example:
///
/// %0 = op ... : tensor<?x?x4x5xf32>
/// with output index_map `affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>`
///
/// and reshape:
/// %1 = tensor.collapse_shape %0 [[0], [1, 2, 3]] :
/// tensor<?x?x4x5xf32> into tensor<?x?xf32>
///
/// would be rewritten into:
/// %0 = op ... : tensor<?x?x4x5xf32>
/// with output index_map
/// `affine_map<(d0, d1, d2, d3) -> (d0, d1 * 20 + d2 * 5 + d3)>`
template <typename TensorReshapeOp>
static AffineMap linearizeCollapsedDims(AffineMap sourceMap,
TensorReshapeOp reshapeOp) {
constexpr bool isExpanding =
std::is_same<TensorReshapeOp, tensor::ExpandShapeOp>::value;
ArrayRef<int64_t> sourceShape =
(isExpanding ? reshapeOp.getResultType().getShape()
: reshapeOp.getSrcType().getShape());
SmallVector<AffineExpr> resultExprs;
ArrayRef<AffineExpr> sourceExprs = sourceMap.getResults();
MLIRContext *context = sourceMap.getContext();
// Compute the result exprs based on the reassociation maps.
for (auto &indices : reshapeOp.getReassociationIndices()) {
// Assume that they are in-order and contiguous (already checked in
// verifier).
assert(!indices.empty());
SmallVector<int64_t> sizes;
SmallVector<AffineExpr> dimExprs;
for (auto en : llvm::zip(sourceShape.slice(indices[0], indices.size()),
sourceExprs.slice(indices[0], indices.size()))) {
if (std::get<0>(en) == 1)
continue;
sizes.push_back(std::get<0>(en));
dimExprs.push_back(std::get<1>(en));
}
AffineExpr linearizedExpr =
makeCanonicalStridedLayoutExpr(sizes, dimExprs, context);
resultExprs.push_back(linearizedExpr);
}
// The new affine map cannot drop unused dimensions, but some new symbols may
// have been added. Create a map with at least as many dimensions/symbols as
// the original affine map.
int64_t maxDim = -1;
int64_t maxSym = -1;
getMaxDimAndSymbol<SmallVector<AffineExpr>>({resultExprs}, maxDim, maxSym);
unsigned numDims = std::max(unsigned(maxDim + 1), sourceMap.getNumDims());
unsigned numSyms = std::max(unsigned(maxSym + 1), sourceMap.getNumSymbols());
return AffineMap::get(numDims, numSyms, resultExprs, context);
}
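As an aside on the arithmetic in the example above (d1 * 20 + d2 * 5 + d3 for the collapsed group with sizes ?, 4, 5): each dimension's multiplier is the product of the sizes that follow it in the group. A small standalone sketch of that computation, independent of the MLIR API:

// Plain C++ illustration (not the MLIR API) of the row-major stride
// computation performed by makeCanonicalStridedLayoutExpr above: for a
// collapsed group with sizes {dyn, 4, 5} the strides are {20, 5, 1}, which
// yields d1 * 20 + d2 * 5 + d3. (The code above additionally drops size-1
// dimensions from the group before forming the expression.)
#include <cstdint>
#include <vector>

std::vector<int64_t> rowMajorStrides(const std::vector<int64_t> &sizes) {
  std::vector<int64_t> strides(sizes.size(), 1);
  for (int i = static_cast<int>(sizes.size()) - 2; i >= 0; --i)
    strides[i] = strides[i + 1] * sizes[i + 1];
  return strides; // e.g. sizes {-1 /*dynamic*/, 4, 5} -> strides {20, 5, 1}
}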
// tensor::ExpandShapeOp is fusable with its consumer (i.e. reshape as a
// producer). Fusing when operand has higher rank will require use of mods and
// divs in the indexing maps of the fused op which would make it non-invertible.
static bool isTensorReshapeOpFoldableByLinearization(
tensor::ExpandShapeOp expandOp, AffineMap useIndexMap, bool asProducer) {
if (!asProducer)
return false;
return useIndexMap.isPermutation();
}
// tensor::CollapseShapeOp is fusable with its producer (i.e. reshape as a
// consumer).
static bool
isTensorReshapeOpFoldableByLinearization(tensor::CollapseShapeOp collapseOp,
AffineMap useIndexMap,
bool asProducer) {
if (asProducer)
return false;
return useIndexMap.isPermutation();
}
/// Check if the reshape operation is only expansion into/collapsing of
/// unit-dimension.
template <typename TensorReshapeOp>
static bool isUnitDimExpansionOnly(TensorReshapeOp reshapeOp) {
constexpr bool isExpanding =
std::is_same<TensorReshapeOp, tensor::ExpandShapeOp>::value;
ArrayRef<int64_t> expandedShape =
(isExpanding ? reshapeOp.getResultType().getShape()
: reshapeOp.getSrcType().getShape());
for (auto &indices : reshapeOp.getReassociationIndices()) {
unsigned numUnitDims = 0;
for (int64_t position : indices)
if (expandedShape[position] == 1)
numUnitDims++;
if (numUnitDims != indices.size() - 1)
return false;
}
return true;
}
namespace {
/// Pattern to fold tensor_expand_shape op with its consumer by using the source
/// of the reshape op as the operand in the consumer (instead of the result of
/// the reshape). The corresponding index map in the consumer
/// needs to be modified to linearize the folded dimension.
///
/// For example,
///
/// #map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
/// %0 = tensor.expand_shape %arg0 [[0], [1, 2], [3]]
/// tensor<?x?x?xf32> into tensor<?x?x4x?xf32>
/// %1 = linalg.generic { indexing_maps = [#map0, #map0, #map0], ... }
/// ins(%0, %arg1 : tensor<?x?x4x?xf32>, tensor<?x?x4x?xf32>) ...
/// -> tensor<?x?x4x?xf32>
///
/// can be folded into
///
/// #map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1 * 4 + d2, d3)>
/// #map1 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
/// %0 = linalg.generic { indexing_maps = [#map0, #map1, #map1] ... }
/// ins(%arg0, %arg1 : tensor<?x?x?xf32>, tensor<?x?x4x?xf32>) ...
/// -> tensor<?x?x4x?xf32>
template <bool foldUnitDimReshapesOnly, typename TensorReshapeOp>
struct FoldProducerReshapeOpByLinearization
: public OpRewritePattern<GenericOp> {
using OpRewritePattern<GenericOp>::OpRewritePattern;
LogicalResult matchAndRewrite(GenericOp genericOp,
PatternRewriter &rewriter) const override {
if (!genericOp.hasTensorSemantics())
return failure();
SmallVector<OpOperand *> inputOperands = genericOp.getInputOperands();
for (const auto &en : llvm::enumerate(inputOperands)) {
auto reshapeOp = en.value()->get().getDefiningOp<TensorReshapeOp>();
if (!reshapeOp)
continue;
if (!isTensorReshapeOpFoldableByLinearization(
reshapeOp, genericOp.getTiedIndexingMap(en.value()),
/*asProducer =*/true) ||
(foldUnitDimReshapesOnly && !isUnitDimExpansionOnly(reshapeOp)))
continue;
// Compute the fused operands list,
SmallVector<Value> fusedOperands = genericOp.getInputOperands();
fusedOperands[en.index()] = reshapeOp.src();
SmallVector<Value> outputOperands = genericOp.getOutputOperands();
llvm::append_range(fusedOperands, outputOperands);
// Compute indexing_maps for the fused operation. The indexing_maps for
// the operands of the consumers that aren't fused are the same.
SmallVector<AffineMap> fusedIndexMaps = genericOp.getIndexingMaps();
// Compute the indexing map to use for the result of the producer.
AffineMap modifiedMap =
linearizeCollapsedDims(fusedIndexMaps[en.index()], reshapeOp);
// The modified map cannot have symbols.
if (modifiedMap.getNumSymbols())
return failure();
for (AffineExpr expr : modifiedMap.getResults()) {
if (!expr.isPureAffine())
return failure();
}
fusedIndexMaps[en.index()] = modifiedMap;
// Further check that the resulting index maps can be fused and
// inverted. Without this the resultant op is not legal.
if (!inversePermutation(concatAffineMaps(fusedIndexMaps))) {
return rewriter.notifyMatchFailure(
genericOp, "fused op loop bound computation failed");
}
rewriter.startRootUpdate(genericOp);
genericOp->setOperands(fusedOperands);
genericOp.indexing_mapsAttr(
rewriter.getAffineMapArrayAttr(fusedIndexMaps));
rewriter.finalizeRootUpdate(genericOp);
return success();
}
return failure();
}
};
/// Pattern to fold tensor_collapse_shape or tensor_expand_shape op with its
/// producer. The corresponding index map in the consumer needs to be modified
/// to linearize the folded dimension.
template <bool foldUnitDimReshapesOnly, typename TensorReshapeOp>
struct FoldConsumerReshapeOpByLinearization
: public OpRewritePattern<TensorReshapeOp> {
using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;
LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
PatternRewriter &rewriter) const override {
GenericOp producer = reshapeOp.src().template getDefiningOp<GenericOp>();
if (!producer || !producer.hasTensorSemantics() ||
producer.getNumOutputs() != 1 ||
!isTensorReshapeOpFoldableByLinearization(
reshapeOp,
producer.getTiedIndexingMap(producer.getOutputOperand(0)),
/*asProducer =*/false) ||
(foldUnitDimReshapesOnly && !isUnitDimExpansionOnly(reshapeOp)))
return failure();
// The indexing_maps for the operands of the fused operation are same as
// those for the operands of the producer.
SmallVector<AffineMap> fusedIndexMaps = producer.getIndexingMaps();
// Compute the indexing map to use for the operand of the producer.
AffineMap modifiedMap = linearizeCollapsedDims(
producer.getTiedIndexingMap(producer.getOutputOperand(0)), reshapeOp);
for (AffineExpr expr : modifiedMap.getResults()) {
if (!expr.isPureAffine()) {
return rewriter.notifyMatchFailure(
producer, "fused op indexing map is not affine");
}
}
fusedIndexMaps.back() = modifiedMap;
// Further check that the resulting index maps can be fused and
// inverted. Without this the resultant op is not legal.
if (!inversePermutation(concatAffineMaps(fusedIndexMaps))) {
return rewriter.notifyMatchFailure(
producer, "fused op loop bound computation failed");
}
Location loc = producer.getLoc();
SmallVector<Value> inputOperands = producer.getInputOperands();
Value output = rewriter.create<TensorReshapeOp>(
loc, producer.getOutputOperand(0)->get(),
reshapeOp.getReassociationExprs());
auto fusedOp = rewriter.create<GenericOp>(
loc, reshapeOp.getResultType(),
/*inputs=*/inputOperands,
// TODO: handle outputs.
/*outputs=*/output, rewriter.getAffineMapArrayAttr(fusedIndexMaps),
producer.iterator_types(),
/*doc=*/nullptr,
/*library_call=*/nullptr);
auto &fusedRegion = fusedOp->getRegion(0);
rewriter.cloneRegionBefore(producer->getRegion(0), fusedRegion,
fusedRegion.begin());
rewriter.replaceOp(reshapeOp, fusedOp->getResults());
return success();
}
};
} // namespace
//===---------------------------------------------------------------------===//
// Methods and patterns that fuse reshape ops with elementwise operations by
// expanding the dimensionality of the elementwise operations.
@ -1737,174 +1480,6 @@ private:
};
} // namespace
//===---------------------------------------------------------------------===//
// Methods and patterns to convert tensor.expand_shape -> linalg.generic
// into linalg.generic -> tensor.expand_shape, i.e. push the reshape down.
//===---------------------------------------------------------------------===//
// TODO(ravishankarm): This pattern is to be deprecated in favor of fusion by
// collapsing that provides a more general functionality. This pattern is very
// specific to a particular use case. The fusion by collapsing can provide the
// same control to clients using the control function there.
static SmallVector<ReassociationIndices>
getReassociationIndices(ArrayRef<AffineMap> maps) {
SmallVector<ReassociationIndices> reassociation;
for (AffineMap map : maps) {
ReassociationIndices indices;
for (unsigned i = 0, e = map.getNumResults(); i < e; i++) {
unsigned pos = map.getResult(i).cast<AffineDimExpr>().getPosition();
indices.push_back(pos);
}
reassociation.push_back(indices);
}
return reassociation;
}
namespace {
/// Pattern to move rank reducing reshape after an elementwise linalg generic
/// op. This is useful to expose more fusion opportunities between named ops and
/// generic ops. This can only be done if there is no broadcast or permutation
/// within the dimensions we need to merge.
///
/// For example,
///
/// %0 = tensor.expand_shape %A [[0, 1], [2]]
/// : tensor<12544x16xf32> into tensor<112x112x16xf32>
/// %2 = linalg.generic {indexing_maps = [
/// affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
/// affine_map<(d0, d1, d2) -> (d2)>,
/// affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types =
/// ["parallel", "parallel", "parallel"]} {
/// } -> tensor<112x112x16xf32>
///
/// into
///
/// %2 = linalg.generic {indexing_maps = [
/// affine_map<(d0, d1) -> (d0, d1)>,
/// affine_map<(d0, d1) -> (d1)>,
/// affine_map<(d0, d1) -> (d0, d1)>],
/// iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg1
/// : tensor<12544x16xf32>, tensor<16xf32>) outs(%1 : tensor<12544x16xf32>) {
/// } -> tensor<12544x16xf32>
/// %3 = tensor.expand_shape %2 [[0, 1], [2]]
/// : tensor<12544x16xf32> into tensor<112x112x16xf32>
struct PushExpandingReshape : public OpRewritePattern<GenericOp> {
using OpRewritePattern<GenericOp>::OpRewritePattern;
LogicalResult matchAndRewrite(GenericOp genericOp,
PatternRewriter &rewriter) const override {
// Only apply to elementwise linalg on tensor.
if (!genericOp.hasTensorSemantics() || genericOp.hasIndexSemantics() ||
genericOp.getNumParallelLoops() != genericOp.getNumLoops())
return failure();
// Only support identity output maps. It could be extended to permutations if
// needed.
if (llvm::any_of(genericOp.getOutputOperands(), [&](OpOperand *opOperand) {
return !genericOp.getTiedIndexingMap(opOperand).isIdentity();
}))
return failure();
int64_t destRank = genericOp.getNumParallelLoops();
SmallVector<Value> newOperands = genericOp.getInputOperands();
tensor::ExpandShapeOp reshapeFound;
// 1. Look for tensor_expand_shape operands and record which dimensions are
// being merged.
SmallVector<OpOperand *> inputOperands = genericOp.getInputOperands();
for (const auto &en : llvm::enumerate(inputOperands)) {
auto reshapeOp =
en.value()->get().template getDefiningOp<tensor::ExpandShapeOp>();
if (!reshapeOp)
continue;
// TODO: We could support non-identity map as long as the merged
// dimensions are still contiguous.
if (!genericOp.getTiedIndexingMap(en.value()).isIdentity())
continue;
if (reshapeFound) {
// Only support a second reshape op if it has the same reassociate maps.
if (reshapeFound.getReassociationMaps() ==
reshapeOp.getReassociationMaps())
newOperands[en.index()] = reshapeOp.src();
continue;
}
reshapeFound = reshapeOp;
newOperands[en.index()] = reshapeOp.src();
}
if (!reshapeFound)
return failure();
// Calculate the reassociation indices and the associated reverse map.
SmallVector<ReassociationIndices> reassociation =
getReassociationIndices(reshapeFound.getReassociationMaps());
SmallVector<unsigned> remap(destRank);
for (auto &indices : llvm::enumerate(reassociation)) {
for (int64_t index : indices.value()) {
remap[index] = indices.index();
}
}
// 2. Verify that we can merge the dimensions in the linalg and that we
// don't need to create new reshapes operands. Inserting new reshape
// operands would defeat the purpose of the transformation.
for (const auto &en : llvm::enumerate(inputOperands)) {
if (en.value()->get() == newOperands[en.index()]) {
AffineMap map = genericOp.getTiedIndexingMap(en.value());
for (unsigned i : llvm::seq(unsigned(0), map.getNumResults())) {
if (reassociation[remap[map.getDimPosition(i)]].size() > 1)
return failure();
}
}
}
// 3. Calculate the affine map remapping and the reassociation to apply to
// output tensors.
SmallVector<AffineMap> newMaps;
unsigned newRank = reassociation.size();
for (auto map : genericOp.getIndexingMaps()) {
SmallVector<AffineExpr> newExprs;
for (auto expr : map.getResults()) {
unsigned position = expr.template cast<AffineDimExpr>().getPosition();
// Skip dimension merged except for the last of the group.
if (reassociation[remap[position]].back() == position) {
newExprs.push_back(
getAffineDimExpr(remap[position], genericOp.getContext()));
}
}
newMaps.push_back(
AffineMap::get(newRank, 0, newExprs, genericOp.getContext()));
}
// 4. Reshape the output tensors.
SmallVector<Value> newOutputs;
SmallVector<Type> newOutputTypes;
for (auto output : genericOp.outputs()) {
auto newOutputType = RankedTensorType::get(
reshapeFound.getSrcType().getShape(),
output.getType().template cast<RankedTensorType>().getElementType());
Value newOutput = rewriter.create<tensor::CollapseShapeOp>(
genericOp->getLoc(), newOutputType, output, reassociation);
newOutputTypes.push_back(newOutputType);
newOutputs.push_back(newOutput);
}
// 5. Create a new generic op with lower rank.
SmallVector<StringRef> iteratorTypes(newRank,
getParallelIteratorTypeName());
auto newOp = rewriter.create<GenericOp>(genericOp->getLoc(), newOutputTypes,
newOperands, newOutputs, newMaps,
iteratorTypes);
rewriter.inlineRegionBefore(genericOp.region(), newOp.region(),
newOp.region().begin());
// 6. Reshape the results so that the types match the uses.
SmallVector<Value> newResults;
for (const auto &result : llvm::enumerate(newOp->getResults())) {
newResults.push_back(rewriter.create<tensor::ExpandShapeOp>(
genericOp->getLoc(), genericOp.getOutputTensorTypes()[result.index()],
result.value(), reassociation));
}
rewriter.replaceOp(genericOp, newResults);
return success();
}
};
} // namespace
//===---------------------------------------------------------------------===//
// Methods and patterns that fuse constants with linalg.generic operations.
//===---------------------------------------------------------------------===//
@ -2093,27 +1668,6 @@ struct FoldFillWithGenericOp : public OpRewritePattern<GenericOp> {
}
};
} // namespace
//===---------------------------------------------------------------------===//
// Methods that add patterns described in this file to a pattern list.
//===---------------------------------------------------------------------===//
void mlir::linalg::populateFoldReshapeOpsByLinearizationPatterns(
RewritePatternSet &patterns) {
patterns.add<
FoldProducerReshapeOpByLinearization<false, tensor::CollapseShapeOp>,
FoldProducerReshapeOpByLinearization<false, tensor::ExpandShapeOp>,
FoldConsumerReshapeOpByLinearization<false, tensor::CollapseShapeOp>>(
patterns.getContext());
}
void mlir::linalg::populateFoldUnitDimsReshapeOpsByLinearizationPatterns(
RewritePatternSet &patterns) {
patterns
.add<FoldProducerReshapeOpByLinearization<true, tensor::CollapseShapeOp>,
FoldProducerReshapeOpByLinearization<true, tensor::ExpandShapeOp>,
FoldConsumerReshapeOpByLinearization<true, tensor::CollapseShapeOp>>(
patterns.getContext());
}
void mlir::linalg::populateFoldReshapeOpsByExpansionPatterns(
RewritePatternSet &patterns,
@ -2140,28 +1694,10 @@ void mlir::linalg::populateElementwiseOpsFusionPatterns(
RemoveOutsDependency>(context);
}
void mlir::linalg::populatePushReshapeOpsPatterns(RewritePatternSet &patterns) {
auto *context = patterns.getContext();
patterns.add<PushExpandingReshape>(context);
}
//===---------------------------------------------------------------------===//
// Passes
//===---------------------------------------------------------------------===//
bool mlir::linalg::skipUnitDimReshape(const OpResult &producer,
OpOperand &consumer) {
if (auto producerCollapseOp =
dyn_cast<tensor::CollapseShapeOp>(producer.getOwner())) {
return !isUnitDimExpansionOnly(producerCollapseOp);
}
if (auto consumerExpandOp =
dyn_cast<tensor::ExpandShapeOp>(consumer.getOwner())) {
return !isUnitDimExpansionOnly(consumerExpandOp);
}
return true;
}
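Out-of-tree users that passed skipUnitDimReshape as a control function can keep the old behaviour by carrying an equivalent helper in their own code. A rough sketch mirroring the removed logic (the names below are illustrative and not part of the API):

// Illustrative stand-in for the removed skipUnitDimReshape, usable as a
// ControlFusionFn in downstream code. It mirrors the deleted logic: skip
// fusion when every reassociation group of the reshape folds only unit
// dimensions (at most one non-unit size per group).
template <typename TensorReshapeOp>
static bool foldsOnlyUnitDims(TensorReshapeOp reshapeOp) {
  constexpr bool isExpanding =
      std::is_same<TensorReshapeOp, tensor::ExpandShapeOp>::value;
  ArrayRef<int64_t> expandedShape =
      isExpanding ? reshapeOp.getResultType().getShape()
                  : reshapeOp.getSrcType().getShape();
  for (auto &indices : reshapeOp.getReassociationIndices()) {
    unsigned numUnitDims = 0;
    for (int64_t pos : indices)
      if (expandedShape[pos] == 1)
        ++numUnitDims;
    if (numUnitDims != indices.size() - 1)
      return false;
  }
  return true;
}

static bool skipUnitDimReshapeControl(const OpResult &producer,
                                      OpOperand &consumer) {
  if (auto collapseOp =
          dyn_cast<tensor::CollapseShapeOp>(producer.getOwner()))
    return !foldsOnlyUnitDims(collapseOp);
  if (auto expandOp = dyn_cast<tensor::ExpandShapeOp>(consumer.getOwner()))
    return !foldsOnlyUnitDims(expandOp);
  return true;
}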
namespace {
/// Pass that fuses generic ops on tensors. Used only for testing.
@ -2186,9 +1722,7 @@ struct LinalgElementwiseOpFusionPass
// Add elementwise op fusion patterns.
populateElementwiseOpsFusionPatterns(patterns, defaultControlFn);
populateFoldReshapeOpsByExpansionPatterns(
patterns,
allowFoldingUnitDimReshapes ? defaultControlFn : skipUnitDimReshape);
populateFoldReshapeOpsByExpansionPatterns(patterns, defaultControlFn);
// Add the sparse tensor rewriting patterns.
populateSparseTensorRewriting(patterns);
@ -2212,27 +1746,8 @@ struct LinalgElementwiseOpFusionPass
}
};
/// Pass to test folding of reshape ops with generic ops by linearization.
struct FoldReshapeOpsByLinearizationPass
: public LinalgFoldReshapeOpsByLinearizationBase<
FoldReshapeOpsByLinearizationPass> {
void runOnOperation() override {
Operation *op = getOperation();
RewritePatternSet patterns(op->getContext());
populateFoldReshapeOpsByLinearizationPatterns(patterns);
if (allowFoldingUnitDimReshapes) {
populateFoldUnitDimsReshapeOpsByLinearizationPatterns(patterns);
}
(void)applyPatternsAndFoldGreedily(op->getRegions(), std::move(patterns));
}
};
} // namespace
std::unique_ptr<Pass> mlir::createLinalgElementwiseOpFusionPass() {
return std::make_unique<LinalgElementwiseOpFusionPass>();
}
std::unique_ptr<Pass> mlir::createFoldReshapeOpsByLinearizationPass() {
return std::make_unique<FoldReshapeOpsByLinearizationPass>();
}

View File

@ -1,4 +1,4 @@
// RUN: mlir-opt %s -test-linalg-elementwise-fusion-patterns=push-expanding-reshape -split-input-file | FileCheck %s
// RUN: mlir-opt %s -test-linalg-elementwise-fusion-patterns=fuse-with-reshape-by-collapsing -split-input-file | FileCheck %s
// CHECK-DAG: #[[$MAP2:.*]] = affine_map<(d0, d1) -> (d0, d1)>
// CHECK-DAG: #[[$MAP3:.*]] = affine_map<(d0, d1) -> (d1)>
@ -124,30 +124,3 @@ func.func @type_correctness(%arg0 : tensor<6x5xi32>, %arg1 : tensor<5xf32>,
// CHECK-SAME: outs(%{{.+}} : tensor<6x5xf32>)
// CHECK: tensor.expand_shape %[[OP]]
// CHECK-SAME: tensor<6x5xf32> into tensor<2x3x5xf32>
// -----
func.func @generic_op_index_semantics(%A: tensor<?x16xi64>, %B: tensor<16xi64>, %init: tensor<?x112x16xi64>) -> tensor<?x112x16xi64> {
%0 = tensor.expand_shape %A [[0, 1], [2]]
: tensor<?x16xi64> into tensor<?x112x16xi64>
%2 = linalg.generic {indexing_maps = [
affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d2)>,
affine_map<(d0, d1, d2) -> (d0, d1, d2)>],
iterator_types = ["parallel", "parallel", "parallel"]}
ins(%0, %B : tensor<?x112x16xi64>, tensor<16xi64>)
outs(%init : tensor<?x112x16xi64>) {
^bb0(%arg1: i64, %arg2: i64, %arg3: i64): // no predecessors
%index = linalg.index 0 : index
%1 = arith.index_cast %index : index to i64
%add = arith.addi %arg1, %1 : i64
%s = arith.subi %add, %arg2 : i64
linalg.yield %s : i64
} -> tensor<?x112x16xi64>
return %2 : tensor<?x112x16xi64>
}
// CHECK: func @generic_op_index_semantics
// CHECK-SAME: %[[ARG0:.+]]: tensor<?x16xi64>
// CHECK: %[[RESHAPE:.+]] = tensor.expand_shape %[[ARG0]]
// CHECK: %[[RESULT:.+]] = linalg.generic
// CHECK-SAME: ins(%[[RESHAPE]]
// CHECK: return %[[RESULT]]

View File

@ -1,5 +1,5 @@
// RUN: mlir-opt %s -linalg-fuse-elementwise-ops="allow-folding-unit-dim-reshapes=false" -split-input-file | FileCheck %s
// RUN: mlir-opt %s -linalg-fuse-elementwise-ops="allow-folding-unit-dim-reshapes=true" -split-input-file | FileCheck %s --check-prefix=FOLDUNITDIM
// RUN: mlir-opt %s -test-linalg-elementwise-fusion-patterns=fuse-with-reshape-by-expansion -split-input-file | FileCheck %s
#map0 = affine_map<(d0, d1, d2) -> (d2, d0, d1)>
#map1 = affine_map<(d0, d1, d2) -> (d1, d2, d0)>
#map2 = affine_map<(d0, d1, d2) -> ()>
@ -14,7 +14,7 @@ func.func @generic_op_reshape_producer_fusion(%arg0 : tensor<?x?x4x?xf32>,
indexing_maps = [#map0, #map1, #map2, #map1],
iterator_types = ["parallel", "parallel", "parallel"]}
ins(%0, %arg1, %arg2 : tensor<?x?x?xf32>, tensor<?x?x?xf32>, f32)
outs(%0 : tensor<?x?x?xf32>) {
outs(%arg1 : tensor<?x?x?xf32>) {
^bb0(%arg3: f32, %arg4: f32, %arg5: f32, %s: f32):
%1 = arith.mulf %arg3, %arg4 : f32
%2 = arith.addf %1, %arg5 : f32
@ -30,15 +30,15 @@ func.func @generic_op_reshape_producer_fusion(%arg0 : tensor<?x?x4x?xf32>,
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?x4x?xf32>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?x?xf32>
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: f32
// CHECK: %[[T0:.+]] = tensor.collapse_shape %[[ARG0]]
// CHECK-SAME: [0], [1, 2], [3]
// CHECK: %[[T1:.+]] = tensor.expand_shape %[[ARG1]]
// CHECK-SAME: [0], [1], [2, 3]
// CHECK: %[[T2:.+]] = tensor.expand_shape %[[ARG1]]
// CHECK-SAME: [0], [1], [2, 3]
// CHECK: %[[T3:.+]] = linalg.generic
// CHECK-SAME: indexing_maps = [#[[MAP5]], #[[MAP6]], #[[MAP7]], #[[MAP6]]]
// CHECK-SAME: ["parallel", "parallel", "parallel", "parallel"]
// CHECK-SAME: ins(%[[ARG0]], %[[T1]], %[[ARG2]] : tensor<?x?x4x?xf32>, tensor<?x?x?x4xf32>, f32)
// CHECK-SAME: outs(%{{.+}} : tensor<?x?x?x4xf32>)
// CHECK-SAME: outs(%[[T2]] : tensor<?x?x?x4xf32>)
// CHECK: %[[T4:.+]] = tensor.collapse_shape %[[T3]]
// CHECK-SAME: [0], [1], [2, 3]
// CHECK-SAME: tensor<?x?x?x4xf32> into tensor<?x?x?xf32>
@ -80,12 +80,14 @@ func.func @generic_op_reshape_consumer_fusion(%arg0 : tensor<?x?xf32>,
// CHECK-SAME: tensor<?x?xf32> into tensor<?x4x?x5xf32>
// CHECK: %[[T1:.+]] = tensor.expand_shape %[[ARG1]]
// CHECK-SAME: [0], [1, 2, 3]
// CHECK: %[[T2:.+]] = tensor.expand_shape %[[ARG0]]
// CHECK-SAME: [0], [1, 2, 3]
// CHECK-SAME: tensor<?x?xf32> into tensor<?x4x?x5xf32>
// CHECK: %[[T3:.+]] = linalg.generic
// CHECK-SAME: indexing_maps = [#[[MAP2]], #[[MAP2]], #[[MAP3]], #[[MAP2]]]
// CHECK-SAME: ["parallel", "parallel", "parallel", "parallel"]
// CHECK-SAME: ins(%[[T0]], %[[T1]], %[[ARG2]] : tensor<?x4x?x5xf32>, tensor<?x4x?x5xf32>, f32)
// CHECK-SAME: outs(%{{.+}} : tensor<?x4x?x5xf32>)
// CHECK-SAME: outs(%[[T2]] : tensor<?x4x?x5xf32>)
// CHECK: return %[[T3]] : tensor<?x4x?x5xf32>
@ -121,11 +123,14 @@ func.func @reshape_as_consumer_permutation
// CHECK: %[[T1:.+]] = tensor.expand_shape %[[ARG1]]
// CHECK-SAME: [0, 1, 2], [3]
// CHECK-SAME: tensor<?x?xf32> into tensor<3x4x?x?xf32>
// CHECK: %[[T2:.+]] = tensor.expand_shape %[[ARG0]]
// CHECK-SAME: [0, 1], [2], [3, 4, 5]]
// CHECK-SAME: tensor<?x?x?xf32> into tensor<?x2x?x3x4x?xf32>
// CHECK: %[[T3:.+]] = linalg.generic
// CHECK-SAME: indexing_maps = [#[[MAP8]], #[[MAP9]], #[[MAP10]]]
// CHECK-SAME: ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel"]
// CHECK-SAME: ins(%[[T0]], %[[T1]] : tensor<3x4x?x?x2x?xf32>, tensor<3x4x?x?xf32>)
// CHECK-SAME: outs(%{{.+}} : tensor<?x2x?x3x4x?xf32>)
// CHECK-SAME: outs(%[[T2]] : tensor<?x2x?x3x4x?xf32>)
// CHECK: return %[[T3]] : tensor<?x2x?x3x4x?xf32>
// -----
@ -155,14 +160,19 @@ func.func @generic_op_reshape_consumer_static(%arg0: tensor<264x4xf32>)
// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
// CHECK: func @generic_op_reshape_consumer_static
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<264x4xf32>
// CHECK-DAG: %[[CST:.+]] = arith.constant
// CHECK-SAME: : tensor<8x33x4xf32>
// CHECK-DAG: %[[INIT:.+]] = linalg.init_tensor [264, 4]
// CHECK: %[[T0:.+]] = tensor.expand_shape %[[ARG0]]
// CHECK-SAME: [0, 1], [2]
// CHECK-SAME: tensor<264x4xf32> into tensor<8x33x4xf32>
// CHECK: %[[T1:.+]] = linalg.init_tensor [8, 33, 4]
// CHECK: %[[T1:.+]] = tensor.expand_shape %[[INIT]]
// CHECK-SAME: [0, 1], [2]
// CHECK-SAME: : tensor<264x4xf32> into tensor<8x33x4xf32>
// CHECK: %[[T2:.+]] = linalg.generic
// CHECK-SAME: indexing_maps = [#[[MAP2]], #[[MAP2]]]
// CHECK-SAME: indexing_maps = [#[[MAP2]], #[[MAP2]], #[[MAP2]]]
// CHECK-SAME: ["parallel", "parallel", "parallel"]
// CHECK-SAME: ins(%[[T0]] : tensor<8x33x4xf32>)
// CHECK-SAME: ins(%[[T0]], %[[CST]] :
// CHECK-SAME: outs(%[[T1]] : tensor<8x33x4xf32>)
// CHECK: return %[[T2]] : tensor<8x33x4xf32>
@ -246,7 +256,8 @@ func.func @indexed_producer_reshape_consumer_fusion(%arg0 : tensor<?x?xi32>,
}
// Only check the body in the indexed version of the test.
// CHECK: #[[MAP:.+]] = affine_map<(d0, d1, d2) -> (d0 + d1 * 5 + d2 * 20)>
// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1) -> (d0 + d1 * 4)>
// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1) -> (d0 + d1 * 5)>
// CHECK: func @indexed_producer_reshape_consumer_fusion
// CHECK: linalg.generic
// CHECK: ^{{.*}}(
@ -256,11 +267,12 @@ func.func @indexed_producer_reshape_consumer_fusion(%arg0 : tensor<?x?xi32>,
// CHECK-DAG: %[[IDX1:.+]] = linalg.index 1 : index
// CHECK-DAG: %[[IDX2:.+]] = linalg.index 2 : index
// CHECK-DAG: %[[IDX3:.+]] = linalg.index 3 : index
// CHECK-DAG: %[[T3:.+]] = affine.apply #[[MAP]](%[[IDX3]], %[[IDX2]], %[[IDX1]])
// CHECK: %[[T1:.+]] = affine.apply #[[MAP1]](%[[IDX2]], %[[IDX1]])
// CHECK: %[[T2:.+]] = affine.apply #[[MAP2]](%[[IDX3]], %[[T1]])
// CHECK: %[[T4:.+]] = arith.muli %[[ARG3]], %[[ARG4]]
// CHECK: %[[T5:.+]] = arith.index_cast %[[IDX0]]
// CHECK: %[[T6:.+]] = arith.addi %[[T4]], %[[T5]]
// CHECK: %[[T7:.+]] = arith.index_cast %[[T3]]
// CHECK: %[[T7:.+]] = arith.index_cast %[[T2]]
// CHECK: %[[T8:.+]] = arith.addi %[[T6]], %[[T7]]
// CHECK: linalg.yield %[[T8]]
@ -295,24 +307,29 @@ func.func @reshape_as_consumer_permutation
return %d : tensor<2x3x4x5x6x7xi32>
}
// -----
// CHECK-DAG: #[[MAP5:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d3, d4, d0, d1, d5)>
// CHECK-DAG: #[[MAP6:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d3, d4, d5)>
// CHECK-DAG: #[[MAP7:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d5, d2, d3, d4)>
// CHECK-DAG: #[[MAP8:.+]] = affine_map<(d0, d1) -> (d0 + d1 * 3)>
// CHECK-DAG: #[[MAP9:.+]] = affine_map<(d0, d1, d2) -> (d0 + d1 * 7 + d2 * 42)>
// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d3, d4, d0, d1, d5)>
// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d3, d4, d5)>
// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d5, d2, d3, d4)>
// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1) -> (d0 + d1 * 3)>
// CHECK-DAG: #[[MAP4:.+]] = affine_map<(d0, d1) -> (d0 + d1 * 6)>
// CHECK-DAG: #[[MAP5:.+]] = affine_map<(d0, d1) -> (d0 + d1 * 7)>
// CHECK: func @reshape_as_consumer_permutation
// CHECK-SAME: %[[ARG0:.+]]: tensor<210x6x4xi32>
// CHECK-SAME: %[[ARG1:.+]]: tensor<210x4xi32>
// CHECK-DAG: %[[INIT:.+]] = linalg.init_tensor [6, 4, 210]
// CHECK-DAG: %[[T1:.+]] = tensor.expand_shape %[[ARG0]]
// CHECK-SAME: [0, 1, 2], [3, 4], [5]
// CHECK-DAG: %[[T2:.+]] = tensor.expand_shape %[[ARG1]]
// CHECK-SAME: [0, 1, 2], [3]
// CHECK-DAG: %[[T0:.+]] = linalg.init_tensor [2, 3, 4, 5, 6, 7]
// CHECK-DAG: %[[T3:.+]] = tensor.expand_shape %[[INIT]]
// CHECK-SAME: [0, 1], [2], [3, 4, 5]
// CHECK-SAME: : tensor<6x4x210xi32> into tensor<2x3x4x5x6x7xi32>
// CHECK: %[[T4:.+]] = linalg.generic
// CHECK-SAME: indexing_maps = [#[[MAP5]], #[[MAP6]], #[[MAP7]]]
// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]]]
// CHECK-SAME: ins(%[[T1]], %[[T2]] : tensor<5x6x7x2x3x4xi32>, tensor<5x6x7x4xi32>)
// CHECK-SAME: outs(%[[T0]] : tensor<2x3x4x5x6x7xi32>)
// CHECK-SAME: outs(%[[T3]] : tensor<2x3x4x5x6x7xi32>)
// CHECK: ^{{.+}}(
// CHECK-SAME: %[[ARG8:[a-zA-Z0-9]+]]: i32, %[[ARG9:[a-zA-Z0-9]+]]: i32,
// CHECK-SAME: %[[ARG10:[a-zA-Z0-9]+]]: i32)
@ -322,15 +339,16 @@ func.func @reshape_as_consumer_permutation
// CHECK-DAG: %[[IDX3:.+]] = linalg.index 3 : index
// CHECK-DAG: %[[IDX4:.+]] = linalg.index 4 : index
// CHECK-DAG: %[[IDX5:.+]] = linalg.index 5 : index
// CHECK-DAG: %[[T5:.+]] = affine.apply #[[MAP8]](%[[IDX1]], %[[IDX0]])
// CHECK-DAG: %[[T6:.+]] = affine.apply #[[MAP9]](%[[IDX4]], %[[IDX3]], %[[IDX2]])
// CHECK-DAG: %[[T7:.+]] = arith.addi %[[ARG8]], %[[ARG9]]
// CHECK: %[[T8:.+]] = arith.index_cast %[[T5]]
// CHECK: %[[T9:.+]] = arith.addi %[[T7]], %[[T8]]
// CHECK: %[[T10:.+]] = arith.index_cast %[[T6]]
// CHECK: %[[T11:.+]] = arith.addi %[[T9]], %[[T10]]
// CHECK: %[[T12:.+]] = arith.index_cast %[[IDX5]]
// CHECK: %[[T13:.+]] = arith.addi %[[T11]], %[[T12]]
// CHECK-DAG: %[[T5:.+]] = affine.apply #[[MAP3]](%[[IDX1]], %[[IDX0]])
// CHECK-DAG: %[[T6:.+]] = affine.apply #[[MAP4]](%[[IDX3]], %[[IDX2]])
// CHECK-DAG: %[[T7:.+]] = affine.apply #[[MAP5]](%[[IDX4]], %[[T6]])
// CHECK-DAG: %[[T8:.+]] = arith.addi %[[ARG8]], %[[ARG9]]
// CHECK: %[[T9:.+]] = arith.index_cast %[[T5]]
// CHECK: %[[T10:.+]] = arith.addi %[[T8]], %[[T9]]
// CHECK: %[[T11:.+]] = arith.index_cast %[[T7]]
// CHECK: %[[T12:.+]] = arith.addi %[[T10]], %[[T11]]
// CHECK: %[[T13:.+]] = arith.index_cast %[[IDX5]]
// CHECK: %[[T14:.+]] = arith.addi %[[T12]], %[[T13]]
// -----
@ -421,94 +439,18 @@ func.func @generic_op_reshape_consumer_fusion_projected(%arg0 : tensor<?x?xf32>,
// CHECK: %[[T1:.+]] = tensor.expand_shape %[[ARG1]]
// CHECK-SAME: [0, 1, 2], [3]
// CHECK-SAME: tensor<?x?xf32> into tensor<?x4x5x?xf32>
// CHECK: %[[T2:.+]] = tensor.expand_shape %[[ARG0]]
// CHECK-SAME: [0], [1, 2, 3]
// CHECK-SAME: tensor<?x?xf32> into tensor<?x?x4x5xf32>
// CHECK: %[[T3:.+]] = linalg.generic
// CHECK-SAME: indexing_maps = [#[[MAP4]], #[[MAP4]], #[[MAP5]]]
// CHECK-SAME: ["parallel", "parallel", "parallel", "parallel"]
// CHECK-SAME: ins(%[[T0]], %[[T1]] : tensor<?x4x5x?xf32>, tensor<?x4x5x?xf32>)
// CHECK-SAME: outs(%{{.+}} : tensor<?x?x4x5xf32>)
// CHECK-SAME: outs(%[[T2]] : tensor<?x?x4x5xf32>)
// CHECK: return %[[T3]] : tensor<?x?x4x5xf32>
// -----
func.func @unit_dim_reshape_expansion(%arg0 : tensor<1x5xf32>) -> tensor<5x5xf32> {
%0 = tensor.collapse_shape %arg0 [[0, 1]]
: tensor<1x5xf32> into tensor<5xf32>
%1 = linalg.init_tensor [5, 5] : tensor<5x5xf32>
%2 = linalg.generic
{indexing_maps = [affine_map<(d0, d1) -> (d0)>,
affine_map<(d0, d1) -> (d0, d1)>],
iterator_types = ["parallel", "parallel"]}
ins(%0 : tensor<5xf32>) outs(%1 : tensor<5x5xf32>) {
^bb0(%arg2: f32, %arg3: f32):
linalg.yield %arg2 : f32
} -> tensor<5x5xf32>
return %2 : tensor<5x5xf32>
}
// CHECK: func @unit_dim_reshape_expansion
// CHECK-DAG: tensor.collapse_shape
// CHECK-DAG: linalg.init_tensor
// CHECK: linalg.generic
// -----
func.func @unit_dim_reshape_collapse(%arg0 : tensor<5xf32>) -> tensor<5x1x5xf32> {
%0 = linalg.init_tensor [5, 5] : tensor<5x5xf32>
%1 = linalg.generic
{indexing_maps = [affine_map<(d0, d1) -> (d0)>,
affine_map<(d0, d1) -> (d0, d1)>],
iterator_types = ["parallel", "parallel"]}
ins(%arg0 : tensor<5xf32>) outs(%0 : tensor<5x5xf32>) {
^bb0(%arg2: f32, %arg3: f32):
linalg.yield %arg2 : f32
} -> tensor<5x5xf32>
%2 = tensor.expand_shape %1 [[0, 1], [2]]
: tensor<5x5xf32> into tensor<5x1x5xf32>
return %2 : tensor<5x1x5xf32>
}
// CHECK: func @unit_dim_reshape_collapse
// CHECK: linalg.init_tensor
// CHECK: linalg.generic
// CHECK: tensor.expand_shape
// -----
func.func @unit_dim_reshape_expansion_full
(%arg0 : tensor<1x?x1x2x1x4xf32>, %arg1 : tensor<?x2x4xf32>)
-> tensor<?x2x4xf32> {
%c1 = arith.constant 1 : index
%0 = tensor.collapse_shape %arg0 [[0, 1, 2], [3, 4], [5]]
: tensor<1x?x1x2x1x4xf32> into tensor<?x2x4xf32>
%1 = tensor.dim %arg0, %c1 : tensor<1x?x1x2x1x4xf32>
%2 = linalg.init_tensor [%1, 2, 4] : tensor<?x2x4xf32>
%3 = linalg.generic
{indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
affine_map<(d0, d1, d2) -> (d0, d1, d2)>],
iterator_types = ["parallel", "parallel", "parallel"]}
ins(%0, %arg1 : tensor<?x2x4xf32>, tensor<?x2x4xf32>)
outs(%2 : tensor<?x2x4xf32>) {
^bb0(%arg2: f32, %arg3: f32, %arg4: f32):
%4 = arith.mulf %arg2, %arg3 : f32
linalg.yield %4 : f32
} -> tensor<?x2x4xf32>
return %3 : tensor<?x2x4xf32>
}
// CHECK: func @unit_dim_reshape_expansion_full
// CHECK-DAG: tensor.collapse_shape
// CHECK-DAG: linalg.init_tensor
// CHECK: linalg.generic
// CHECK-SAME: ins(%{{.+}}, %{{.+}} : tensor<?x2x4xf32>, tensor<?x2x4xf32>)
// FOLDUNITDIM: func @unit_dim_reshape_expansion_full
// FOLDUNITDIM-SAME: %[[ARG0:.+]]: tensor<1x?x1x2x1x4xf32>
// FOLDUNITDIM-SAME: %[[ARG1:.+]]: tensor<?x2x4xf32>
// FOLDUNITDIM-DAG: %[[RESHAPE:.+]] = tensor.expand_shape %[[ARG1]]
// FOLDUNITDIM: linalg.generic
// FOLDUNITDIM-SAME: ins(%[[ARG0]], %[[RESHAPE]] : tensor<1x?x1x2x1x4xf32>, tensor<1x?x1x2x1x4xf32>)
// FOLDUNITDIM-SAME: outs(%{{.+}} : tensor<1x?x1x2x1x4xf32>)
// -----
func.func @no_fuse_dynamic_dims(%arg0: tensor<?x?xf32>) -> tensor<?xf32> {
%c0 = arith.constant 0 : index
%0 = tensor.collapse_shape %arg0 [[0, 1]] : tensor<?x?xf32> into tensor<?xf32>
@ -554,7 +496,6 @@ func.func @no_fuse_mismatched_dynamism(%arg0: tensor<2x1xi64>, %arg1: tensor<?xi
// CHECK-SAME: %[[ARG0:.+]]: tensor<2x1xi64>
// CHECK-SAME: %[[ARG1:.+]]: tensor<?xi64>
// CHECK: %[[RESHAPE:.+]] = tensor.collapse_shape %[[ARG0]]
// CHECK: %[[CAST:.+]] = tensor.cast %[[ARG1]] : tensor<?xi64> to tensor<2xi64>
// CHECK: %[[GENERIC:.+]] = linalg.generic
// CHECK-SAME: ins(%[[RESHAPE]], %[[CAST]] : tensor<2xi64>, tensor<2xi64>)
// CHECK-SAME: ins(%[[RESHAPE]], %[[ARG1]] : tensor<2xi64>, tensor<?xi64>)
// CHECK: return %[[GENERIC]]

View File

@ -1,287 +0,0 @@
// RUN: mlir-opt -split-input-file -linalg-fold-reshape-ops-by-linearization %s | FileCheck %s
// Note: These tests fuse the reshape ops by linearization. This can create
// indexing maps which are hard to analyse later on. These patterns are useful
// only if the folded dimensions in the reshape op are unit extent. Tests here
// are more general for testing purposes, but use of these patterns for non-unit
// dimensions should be deprecated.
#map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
func.func @generic_op_reshape_producer_fusion(%arg0 : tensor<?x?x?xi32>)
-> tensor<?x?x4x?xi32> {
%0 = tensor.expand_shape %arg0 [[0], [1, 2], [3]] :
tensor<?x?x?xi32> into tensor<?x?x4x?xi32>
%1 = linalg.generic {
indexing_maps = [#map0, #map0],
iterator_types = ["parallel", "parallel", "parallel", "parallel"] }
ins(%0 : tensor<?x?x4x?xi32>)
outs(%0 : tensor<?x?x4x?xi32>) {
^bb0(%arg6: i32, %arg7 : i32):
%idx = linalg.index 0 : index
%2 = arith.index_cast %idx : index to i32
%3 = arith.addi %arg6, %2 : i32
linalg.yield %3 : i32
} -> tensor<?x?x4x?xi32>
return %1 : tensor<?x?x4x?xi32>
}
// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1 * 4 + d2, d3)>
// CHECK-DAG: #[[MAP4:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
// CHECK: func @generic_op_reshape_producer_fusion
// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?x?xi32>
// CHECK: %[[T0:.+]] = tensor.expand_shape %[[ARG0]]
// CHECK-SAME: [0], [1, 2], [3]
// CHECK: linalg.generic
// CHECK-SAME: indexing_maps = [#[[MAP3]], #[[MAP4]]]
// CHECK-SAME: ins(%[[ARG0]] : tensor<?x?x?xi32>)
// CHECK-SAME: outs(%[[T0]] : tensor<?x?x4x?xi32>)
// CHECK: %[[IDX:.+]] = linalg.index 0 : index
// CHECK-NEXT: %[[IDX_CASTED:.+]] = arith.index_cast %[[IDX]] : index to i32
// -----
#map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
func.func @generic_op_reshape_consumer_fusion(%arg0 : tensor<?x?x4x5xi32>)
-> tensor<?x?xi32> {
%0 = linalg.generic {
indexing_maps = [#map0, #map0],
iterator_types = ["parallel", "parallel", "parallel", "parallel"] }
ins(%arg0 : tensor<?x?x4x5xi32>) outs(%arg0 : tensor<?x?x4x5xi32>) {
^bb0(%arg6: i32, %arg7: i32):
%idx = linalg.index 0 : index
%2 = arith.index_cast %idx : index to i32
%3 = arith.addi %arg6, %2 : i32
linalg.yield %3 : i32
} -> tensor<?x?x4x5xi32>
%1 = tensor.collapse_shape %0 [[0], [1, 2, 3]] :
tensor<?x?x4x5xi32> into tensor<?x?xi32>
return %1 : tensor<?x?xi32>
}
// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1 * 20 + d2 * 5 + d3)>
// CHECK: func @generic_op_reshape_consumer_fusion
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?x4x5xi32>
// CHECK: %[[T0:.+]] = tensor.collapse_shape %[[ARG0]]
// CHECK-SAME: [0], [1, 2, 3]
// CHECK: linalg.generic
// CHECK-SAME: indexing_maps = [#[[MAP2]], #[[MAP3]]]
// CHECK-SAME: outs(%[[T0]] : tensor<?x?xi32>)
// CHECK: %[[IDX:.+]] = linalg.index 0 : index
// CHECK-NEXT: %[[IDX_CASTED:.+]] = arith.index_cast %[[IDX]] : index to i32
// CHECK-NOT: tensor.collapse_shape
// -----
#map2 = affine_map<(d0, d1, d2) -> (d0, d2, d1)>
#map3 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
func.func @generic_op_021_permultation_reshape_producer_fusion(%arg0 : tensor<3x35xf32>) -> tensor<3x7x5xf32> {
%0 = tensor.expand_shape %arg0 [[0], [1, 2]]
: tensor<3x35xf32> into tensor<3x5x7xf32>
%1 = linalg.init_tensor [3, 7, 5] : tensor<3x7x5xf32>
%2 = linalg.generic
{indexing_maps = [#map2, #map3],
iterator_types = ["parallel", "parallel", "parallel"]}
ins(%0 : tensor<3x5x7xf32>) outs(%1 : tensor<3x7x5xf32>) {
^bb0(%arg2: f32, %arg3 : f32):
linalg.yield %arg2 : f32
} -> tensor<3x7x5xf32>
return %2 : tensor<3x7x5xf32>
}
// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2) -> (d0, d1 + d2 * 7)>
// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
// CHECK: func @generic_op_021_permultation_reshape_producer_fusion
// CHECK-NOT: tensor.expand_shape
// CHECK: linalg.generic
// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]]]
// -----
#map2 = affine_map<(d0, d1, d2) -> (d1, d0, d2)>
#map3 = affine_map<(d0, d1, d2) -> (d0, d2, d1)>
func.func @generic_op_120_permutation_reshape_producer_fusion(%arg0 : tensor<3x35xf32>) -> tensor<5x7x3xf32> {
%0 = tensor.expand_shape %arg0 [[0], [1, 2]]
: tensor<3x35xf32> into tensor<3x5x7xf32>
%1 = linalg.init_tensor [5, 7, 3] : tensor<5x7x3xf32>
%2 = linalg.generic
{indexing_maps = [#map2, #map3],
iterator_types = ["parallel", "parallel", "parallel"]}
ins(%0 : tensor<3x5x7xf32>) outs(%1 : tensor<5x7x3xf32>) {
^bb0(%arg2: f32, %arg3: f32):
linalg.yield %arg2 : f32
} -> tensor<5x7x3xf32>
return %2 : tensor<5x7x3xf32>
}
// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2) -> (d1, d0 * 7 + d2)>
// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2) -> (d0, d2, d1)>
// CHECK: func @generic_op_120_permutation_reshape_producer_fusion
// CHECK-NOT: tensor.expand_shape
// CHECK: linalg.generic
// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]]]
// -----
#map0 = affine_map<(d0, d1, d2) -> (d0)>
#map1 = affine_map<(d0, d1, d2) -> (d1, d2)>
#map2 = affine_map<(d0, d1, d2) -> (d1, d0, d2)>
#map3 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
func.func @generic_op_102_permultation_reshape_producer_fusion(%arg0 : tensor<3x35xf32>) -> tensor<5x3x7xf32> {
%0 = tensor.expand_shape %arg0 [[0], [1, 2]]
: tensor<3x35xf32> into tensor<3x5x7xf32>
%1 = linalg.init_tensor [5, 3, 7] : tensor<5x3x7xf32>
%2 = linalg.generic
{indexing_maps = [#map2, #map3],
iterator_types = ["parallel", "parallel", "parallel"]}
ins(%0 : tensor<3x5x7xf32>) outs(%1 : tensor<5x3x7xf32>) {
^bb0(%arg2: f32, %arg3: f32):
linalg.yield %arg2 : f32
} -> tensor<5x3x7xf32>
return %2 : tensor<5x3x7xf32>
}
// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2) -> (d1, d0 * 7 + d2)>
// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
// CHECK: func @generic_op_102_permultation_reshape_producer_fusion
// CHECK-NOT: tensor.expand_shape
// CHECK: linalg.generic
// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]]]
// -----
#map0 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
#map1 = affine_map<(d0, d1, d2) -> (d1, d0, d2)>
#map2 = affine_map<(d0, d1, d2) -> (d0)>
#map3 = affine_map<(d0, d1, d2) -> (d1, d2)>
func.func @generic_op_102_permultation_reshape_consumer_fusion(%arg0 : tensor<3x5x7xf32>) -> tensor<5x21xf32> {
%0 = linalg.init_tensor [5, 3, 7] : tensor<5x3x7xf32>
%1 = linalg.generic
{indexing_maps = [#map0, #map1],
iterator_types = ["parallel", "parallel", "parallel"]}
ins(%arg0 : tensor<3x5x7xf32>) outs(%0 : tensor<5x3x7xf32>) {
^bb0(%arg2: f32, %arg3 : f32):
linalg.yield %arg2 : f32
} -> tensor<5x3x7xf32>
%2 = tensor.collapse_shape %1 [[0], [1, 2]]
: tensor<5x3x7xf32> into tensor<5x21xf32>
return %2 : tensor<5x21xf32>
}
// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1, d2) -> (d1, d0 * 7 + d2)>
// CHECK: func @generic_op_102_permultation_reshape_consumer_fusion
// CHECK-SAME: %[[ARG0:.+]]: tensor<3x5x7xf32>
// CHECK: %[[T0:.+]] = linalg.init_tensor [5, 3, 7]
// CHECK: %[[T1:.+]] = tensor.collapse_shape %[[T0]]
// CHECK-SAME: [0], [1, 2]
// CHECK: linalg.generic
// CHECK-SAME: indexing_maps = [#[[MAP2]], #[[MAP3]]]
// CHECK-SAME: ins(%[[ARG0]] : tensor<3x5x7xf32>)
// CHECK-SAME: outs(%[[T1]] : tensor<5x21xf32>)
// -----
#map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
func.func @generic_op_reshape_consumer_nofusion(%arg0 : tensor<?x?x?x5xf32>,
%arg1 : tensor<?x?x?x5xf32>) ->
tensor<?x?xf32>
{
%0 = linalg.generic {
indexing_maps = [#map0, #map0, #map0],
iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
ins(%arg0, %arg1 : tensor<?x?x?x5xf32>, tensor<?x?x?x5xf32>)
outs(%arg0 : tensor<?x?x?x5xf32>) {
^bb0(%arg3: f32, %arg4: f32, %arg5: f32):
%1 = arith.mulf %arg3, %arg4 : f32
linalg.yield %1 : f32
} -> tensor<?x?x?x5xf32>
%1 = tensor.collapse_shape %0 [[0], [1, 2, 3]] :
tensor<?x?x?x5xf32> into tensor<?x?xf32>
return %1 : tensor<?x?xf32>
}
// CHECK-LABEL: func @generic_op_reshape_consumer_nofusion
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?x?x5xf32>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?x?x5xf32>
// CHECK: %[[NOFUSE:.+]] = linalg.generic
// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]]
// CHECK: %[[RESULT:.+]] = tensor.collapse_shape %[[NOFUSE]]
// CHECK: return %[[RESULT]]
// -----
func.func @generic_op_permultation_reshape_consumer_fusion_unused_dim(%arg0 : tensor<6x1xf32>) -> tensor<6xi32> {
%0 = linalg.init_tensor [6, 1] : tensor<6x1xi32>
%1 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
affine_map<(d0, d1) -> (d0, d1)>],
iterator_types = ["parallel", "parallel"]}
ins(%arg0 : tensor<6x1xf32>) outs(%0 : tensor<6x1xi32>) {
^bb0(%arg3: f32, %arg4: i32):
%5 = arith.fptosi %arg3 : f32 to i32
linalg.yield %5 : i32
} -> tensor<6x1xi32>
%6 = tensor.collapse_shape %1 [[0, 1]] : tensor<6x1xi32> into tensor<6xi32>
return %6 : tensor<6xi32>
}
// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1) -> (d0, d1)>
// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1) -> (d0)>
// CHECK: func @generic_op_permultation_reshape_consumer_fusion_unused_dim
// CHECK-SAME: %[[ARG0:.+]]: tensor<6x1xf32>
// CHECK: %[[T0:.+]] = linalg.init_tensor [6, 1]
// CHECK: %[[T1:.+]] = tensor.collapse_shape %[[T0]]
// CHECK-SAME: [0, 1]
// CHECK: linalg.generic
// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]]]
// CHECK-SAME: ins(%[[ARG0]] : tensor<6x1xf32>)
// CHECK-SAME: outs(%[[T1]] : tensor<6xi32>)
// -----
#map0 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d2, d4, d0, d6, d3, d5, d1)>
#map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3, d4, d5, d6)>
func.func @permuted_dims_fusion_expand_shape(%arg0 : tensor<3x8x7x240xf32>) -> tensor<4x6x3x8x2x5x7xf32> {
%0 = tensor.expand_shape %arg0 [[0], [1, 2], [3], [4, 5, 6]]
: tensor<3x8x7x240xf32> into tensor<3x2x4x7x8x5x6xf32>
%1 = linalg.init_tensor [4, 6, 3, 8, 2, 5, 7] : tensor<4x6x3x8x2x5x7xf32>
%2 = linalg.generic {
indexing_maps = [#map0, #map1],
iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel", "parallel"]}
ins(%0 : tensor<3x2x4x7x8x5x6xf32>) outs(%1 : tensor<4x6x3x8x2x5x7xf32>) {
^bb0(%arg1 : f32, %arg2 : f32):
linalg.yield %arg1 : f32
} -> tensor<4x6x3x8x2x5x7xf32>
return %2 : tensor<4x6x3x8x2x5x7xf32>
}
// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d2, d0 + d4 * 4, d6, d1 + d3 * 30 + d5 * 6)>
// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3, d4, d5, d6)>
// CHECK: func @permuted_dims_fusion_expand_shape(
// CHECK-SAME: %[[ARG0:.+]]: tensor<3x8x7x240xf32>)
// CHECK: %[[RESULT:.+]] = linalg.generic
// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]]]
// CHECK-SAME: ins(%[[ARG0]] :
// CHECK: return %[[RESULT]]
// -----
#map0 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d2, d4, d0, d6, d3, d5, d1)>
#map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3, d4, d5, d6)>
func.func @permuted_dims_fusion_collapse_shape(%arg0 : tensor<4x6x3x8x2x5x7xf32>) -> tensor<3x8x7x240xf32> {
%0 = linalg.init_tensor [3, 2, 4, 7, 8, 5, 6] : tensor<3x2x4x7x8x5x6xf32>
%1 = linalg.generic {
indexing_maps = [#map1, #map0],
iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel", "parallel"]}
ins(%arg0 : tensor<4x6x3x8x2x5x7xf32>) outs(%0 : tensor<3x2x4x7x8x5x6xf32>) {
^bb0(%arg1 : f32, %arg2 : f32):
linalg.yield %arg1 : f32
} -> tensor<3x2x4x7x8x5x6xf32>
%2 = tensor.collapse_shape %1 [[0], [1, 2], [3], [4, 5, 6]]
: tensor<3x2x4x7x8x5x6xf32> into tensor<3x8x7x240xf32>
return %2 : tensor<3x8x7x240xf32>
}
// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3, d4, d5, d6)>
// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d2, d0 + d4 * 4, d6, d1 + d3 * 30 + d5 * 6)>
// CHECK: func @permuted_dims_fusion_collapse_shape(
// CHECK-SAME: %[[ARG0:.+]]: tensor<4x6x3x8x2x5x7xf32>)
// CHECK: %[[RESULT:.+]] = linalg.generic
// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]]]
// CHECK-SAME: ins(%[[ARG0]] :
// CHECK: return %[[RESULT]]

View File

@ -1,52 +0,0 @@
// RUN: mlir-opt -linalg-fold-reshape-ops-by-linearization=allow-folding-unit-dim-reshapes -split-input-file %s | FileCheck %s
#map = affine_map<(d0, d1) -> (d0, d1)>
func.func @do_not_fold1(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>) -> tensor<?x?x1xf32>
{
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%0 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
%1 = tensor.dim %arg0, %c1 : tensor<?x?xf32>
%2 = linalg.init_tensor [%0, %1] : tensor<?x?xf32>
%3 = linalg.generic {
indexing_maps = [#map, #map, #map],
iterator_types = ["parallel", "parallel"]}
ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
outs(%2 : tensor<?x?xf32>) {
^bb0(%arg2 : f32, %arg3 : f32, %arg4 : f32):
%4 = arith.addf %arg2, %arg3 : f32
linalg.yield %4 : f32
} -> tensor<?x?xf32>
%4 = tensor.expand_shape %3 [[0], [1, 2]] : tensor<?x?xf32> into tensor<?x?x1xf32>
return %4 : tensor<?x?x1xf32>
}
// CHECK-LABEL: func @do_not_fold1
// CHECK: %[[VAL:.+]] = linalg.generic
// CHECK: tensor.expand_shape %[[VAL]]
// -----
#map = affine_map<(d0, d1) -> (d0, d1)>
func.func @do_not_fold2(%arg0 : tensor<?x?x1xf32>, %arg1 : tensor<?x?xf32>) -> tensor<?x?xf32>
{
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%0 = tensor.collapse_shape %arg0 [[0], [1, 2]] : tensor<?x?x1xf32> into tensor<?x?xf32>
%1 = tensor.dim %arg1, %c0 : tensor<?x?xf32>
%2 = tensor.dim %arg1, %c1 : tensor<?x?xf32>
%3 = linalg.init_tensor [%1, %2] : tensor<?x?xf32>
%4 = linalg.generic {
indexing_maps = [#map, #map, #map],
iterator_types = ["parallel", "parallel"]}
ins(%0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
outs(%3 : tensor<?x?xf32>) {
^bb0(%arg2 : f32, %arg3 : f32, %arg4 : f32):
%4 = arith.addf %arg2, %arg3 : f32
linalg.yield %4 : f32
} -> tensor<?x?xf32>
return %4 : tensor<?x?xf32>
}
// CHECK-LABEL: func @do_not_fold2
// CHECK: %[[VAL:.+]] = tensor.collapse_shape
// CHECK: linalg.generic
// CHECK-SAME: ins(%[[VAL]], %{{.+}} : tensor<?x?xf32>, tensor<?x?xf32>)

View File

@ -70,18 +70,18 @@ struct TestLinalgElementwiseFusion
llvm::cl::desc("Test fusion of generic operations."),
llvm::cl::init(false)};
Option<bool> fuseWithReshapeByExpansion{
*this, "fuse-with-reshape-by-expansion",
llvm::cl::desc(
"Test fusion of generic operations with reshape by expansion"),
llvm::cl::init(false)};
Option<bool> controlFuseByExpansion{
*this, "control-fusion-by-expansion",
llvm::cl::desc(
"Test controlling fusion of reshape with generic op by expansion"),
llvm::cl::init(false)};
Option<bool> pushExpandingReshape{
*this, "push-expanding-reshape",
llvm::cl::desc("Test linalg expand_shape -> generic "
"to generic -> expand_shape pattern"),
llvm::cl::init(false)};
Option<bool> fuseWithReshapeByCollapsing{
*this, "fuse-with-reshape-by-collapsing",
llvm::cl::desc("Test linalg expand_shape -> generic fusion patterns that "
@ -109,6 +109,17 @@ struct TestLinalgElementwiseFusion
return;
}
if (fuseWithReshapeByExpansion) {
RewritePatternSet fusionPatterns(context);
linalg::populateFoldReshapeOpsByExpansionPatterns(
fusionPatterns, [](const OpResult & /*producer*/,
OpOperand & /*consumer*/) { return true; });
if (failed(applyPatternsAndFoldGreedily(funcOp.getBody(),
std::move(fusionPatterns))))
return signalPassFailure();
return;
}
if (controlFuseByExpansion) {
RewritePatternSet fusionPatterns(context);
@ -128,8 +139,9 @@ struct TestLinalgElementwiseFusion
if (linalgOp && linalgOp.isOutputTensor(&use))
return true;
}
return false;
}
return linalg::skipUnitDimReshape(producer, consumer);
return true;
};
linalg::populateFoldReshapeOpsByExpansionPatterns(fusionPatterns,
@ -139,12 +151,6 @@ struct TestLinalgElementwiseFusion
return;
}
if (pushExpandingReshape) {
RewritePatternSet patterns(context);
linalg::populatePushReshapeOpsPatterns(patterns);
(void)applyPatternsAndFoldGreedily(funcOp.getBody(), std::move(patterns));
}
if (fuseWithReshapeByCollapsing) {
RewritePatternSet patterns(context);
linalg::populateFoldReshapeOpsByCollapsingPatterns(