[mlir][Linalg] Deprecate legacy reshape + generic op folding patterns.

These patterns have been superseded by the fusion-by-collapsing patterns.

Differential Revision: https://reviews.llvm.org/D124145
commit 0c090dcc8a (parent e8572aca0c)
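The deprecated linearization entry points still appear in the hunks below. A minimal migration sketch for downstream users, assuming a helper that previously collected the linearization patterns (the helper name and the always-true control function are illustrative, not part of this commit):

#include "mlir/Dialect/Linalg/Transforms/Transforms.h"

using namespace mlir;

// Hypothetical helper: collect reshape-fusion patterns for a pass pipeline.
static void collectReshapeFusionPatterns(RewritePatternSet &patterns) {
  // Before this commit:
  //   linalg::populateFoldReshapeOpsByLinearizationPatterns(patterns);
  // After it, use fusion by collapsing. The control function decides, per
  // reshape/generic pair, whether folding should happen; returning true
  // everywhere folds unconditionally.
  linalg::ControlFusionFn controlFn = [](const OpResult &producer,
                                         OpOperand &consumer) { return true; };
  linalg::populateFoldReshapeOpsByCollapsingPatterns(patterns, controlFn);
}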
@@ -90,31 +90,11 @@ def LinalgFoldUnitExtentDims : Pass<"linalg-fold-unit-extent-dims", ""> {
def LinalgElementwiseOpFusion : Pass<"linalg-fuse-elementwise-ops"> {
  let summary = "Fuse elementwise operations on tensors";
  let constructor = "mlir::createLinalgElementwiseOpFusionPass()";
  let options = [
    Option<"allowFoldingUnitDimReshapes", "allow-folding-unit-dim-reshapes",
           "bool", /*default=*/"false",
           "Allow fusing linalg.tensor_reshape ops that perform unit "
           "dimension collapsing">
  ];
  let dependentDialects = [
    "AffineDialect", "linalg::LinalgDialect", "memref::MemRefDialect"
  ];
}

def LinalgFoldReshapeOpsByLinearization :
    Pass<"linalg-fold-reshape-ops-by-linearization"> {
  let summary = "Fold TensorReshapeOps with generic/indexed generic ops by "
                "linearization";
  let constructor = "mlir::createFoldReshapeOpsByLinearizationPass()";
  let options = [
    Option<"allowFoldingUnitDimReshapes", "allow-folding-unit-dim-reshapes",
           "bool", /*default=*/"false",
           "Allow fusing linalg.tensor_reshape ops that perform unit "
           "dimension collapsing">
  ];
  let dependentDialects = ["AffineDialect", "memref::MemRefDialect"];
}

def LinalgNamedOpConversion: Pass<"linalg-named-op-conversion"> {
  let summary = "Convert from one named linalg op to another.";
  let constructor = "mlir::createLinalgNamedOpConversionPass()";
@@ -37,10 +37,6 @@ struct LinalgElementwiseFusionOptions;
struct LinalgFusionOptions;
struct LinalgTilingOptions;

/// Default function to control reshape folding. Skips folding unit dimension
/// reshapes.
bool skipUnitDimReshape(const OpResult &producer, OpOperand &consumer);

//===----------------------------------------------------------------------===//
// Transformations exposed as function calls.
//===----------------------------------------------------------------------===//

@@ -91,24 +87,6 @@ void populateFoldReshapeOpsByCollapsingPatterns(
void populateConstantFoldLinalgOperations(RewritePatternSet &patterns,
                                          const ControlFusionFn &controlFn);

/// Patterns to fold a collapsing (expanding) tensor_reshape operation with its
/// producer (consumer) generic operation by linearizing the indexing map used
/// to access the source (target) of the reshape operation in the generic
/// operation.
/// TODO(ravishankarm): These patterns are to be deprecated in favor of using
/// the `populateFoldReshapeByCollapsingPatterns`.
void populateFoldReshapeOpsByLinearizationPatterns(RewritePatternSet &patterns);

/// Patterns to fold a collapsing (expanding) tensor_reshape operation with its
/// producer (consumer) generic operation by linearizing the indexing map used
/// to access the source (target) of the reshape operation in the generic
/// operation. The patterns are applied only when the tensor reshape involved
/// is collapsing (introducing) unit-extent dimensions.
/// TODO(ravishankarm): These patterns are to be deprecated in favor of using
/// the `populateFoldReshapeByCollapsingPatterns`.
void populateFoldUnitDimsReshapeOpsByLinearizationPatterns(
    RewritePatternSet &patterns);

/// Pattern to fuse a `tensor.pad` operation with the producer of its source,
/// if the producer is a `linalg` operation with all parallel iterator types.
void populateFuseTensorPadWithProducerLinalgOpPatterns(

@@ -128,12 +106,6 @@ void populateInlineConstantOperandsPatterns(RewritePatternSet &patterns);
/// Patterns that are used to bubble up extract slice op above linalg op.
void populateBubbleUpExtractSliceOpPatterns(RewritePatternSet &patterns);

/// Patterns to push reshape op towards the end of the graph in order to expose
/// more fusion opportunities.
/// TODO(ravishankarm): These patterns are to be deprecated in favor of using
/// the `populateFoldReshapeByCollapsingPatterns`.
void populatePushReshapeOpsPatterns(RewritePatternSet &patterns);

/// Perform standalone tiling of a single LinalgOp by `tileSizes`
/// and permute the loop nest according to `interchangeVector`.
/// The permutation is expressed as a list of integers that specify
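With the unit-dim-only variants removed, the remaining knob is the `ControlFusionFn` callback shared by the surviving populate functions above. A sketch of a custom control function with the same signature as `skipUnitDimReshape` (the policy it implements is illustrative, not something this commit prescribes):

#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"

using namespace mlir;

// Hypothetical policy: only fold reshapes whose reshaped type is fully
// static; reject anything involving dynamic shapes.
static bool foldStaticReshapesOnly(const OpResult &producer,
                                   OpOperand &consumer) {
  if (auto expandOp = dyn_cast<tensor::ExpandShapeOp>(consumer.getOwner()))
    return expandOp.getResultType().hasStaticShape();
  if (auto collapseOp = dyn_cast<tensor::CollapseShapeOp>(producer.getOwner()))
    return collapseOp.getResultType().hasStaticShape();
  return false;
}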
@@ -392,263 +392,6 @@ private:
};
} // namespace

//===---------------------------------------------------------------------===//
// Methods and patterns that fuse reshape ops with elementwise operations by
// linearization of indexing maps.
//===---------------------------------------------------------------------===//

// TODO(ravishankarm): The indexing maps these produce in the general case are
// detrimental to transformations. These patterns are on the deprecation path
// in favor of fusion by collapsing, which covers the only legitimate use case
// of this pattern: folding unit-extent dims.

/// Linearize the expressions in `sourceMap` based on the `reassociationMaps`
/// provided, given the shape of the source tensor that corresponds to the
/// `sourceMap`. Note that this implicitly assumes that the tensor's dimensions
/// are "row-major" ordered logically.
///
/// For example:
///
///   %0 = op ... : tensor<?x?x4x5xf32>
///   with output index_map `affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>`
///
/// and reshape:
///   %1 = tensor.collapse_shape %0 [[0], [1, 2, 3]] :
///          tensor<?x?x4x5xf32> into tensor<?x?xf32>
///
/// would be rewritten into:
///   %0 = op ... : tensor<?x?x4x5xf32>
///   with output index_map
///     `affine_map<(d0, d1, d2, d3) -> (d0, d1 * 20 + d2 * 5 + d3)>`
template <typename TensorReshapeOp>
static AffineMap linearizeCollapsedDims(AffineMap sourceMap,
                                        TensorReshapeOp reshapeOp) {
  constexpr bool isExpanding =
      std::is_same<TensorReshapeOp, tensor::ExpandShapeOp>::value;
  ArrayRef<int64_t> sourceShape =
      (isExpanding ? reshapeOp.getResultType().getShape()
                   : reshapeOp.getSrcType().getShape());
  SmallVector<AffineExpr> resultExprs;
  ArrayRef<AffineExpr> sourceExprs = sourceMap.getResults();
  MLIRContext *context = sourceMap.getContext();

  // Compute the result exprs based on the reassociation maps.
  for (auto &indices : reshapeOp.getReassociationIndices()) {
    // Assume that they are in-order and contiguous (already checked in
    // verifier).
    assert(!indices.empty());
    SmallVector<int64_t> sizes;
    SmallVector<AffineExpr> dimExprs;
    for (auto en : llvm::zip(sourceShape.slice(indices[0], indices.size()),
                             sourceExprs.slice(indices[0], indices.size()))) {
      if (std::get<0>(en) == 1)
        continue;
      sizes.push_back(std::get<0>(en));
      dimExprs.push_back(std::get<1>(en));
    }
    AffineExpr linearizedExpr =
        makeCanonicalStridedLayoutExpr(sizes, dimExprs, context);
    resultExprs.push_back(linearizedExpr);
  }
  // The new affine map cannot drop unused dimensions, but some new symbols may
  // have been added. Create a map with at least as many dimensions/symbols as
  // the original affine map.
  int64_t maxDim = -1;
  int64_t maxSym = -1;
  getMaxDimAndSymbol<SmallVector<AffineExpr>>({resultExprs}, maxDim, maxSym);
  unsigned numDims = std::max(unsigned(maxDim + 1), sourceMap.getNumDims());
  unsigned numSyms = std::max(unsigned(maxSym + 1), sourceMap.getNumSymbols());
  return AffineMap::get(numDims, numSyms, resultExprs, context);
}
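A quick sanity check of the doc example above: for the collapsed group [d1, d2, d3] with sizes [?, 4, 5], the canonical row-major strides are 1 for d3, 5 for d2, and 4 * 5 = 20 for d1, which is exactly the `d1 * 20 + d2 * 5 + d3` expression in the comment (unit-extent entries are dropped before the stride computation).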

// tensor::ExpandShapeOp is fusable with its consumer (i.e. reshape as a
// producer). Fusing when the operand has higher rank would require mods and
// divs in the indexing maps of the fused op, which would make it
// non-invertible.
static bool isTensorReshapeOpFoldableByLinearization(
    tensor::ExpandShapeOp expandOp, AffineMap useIndexMap, bool asProducer) {
  if (!asProducer)
    return false;
  return useIndexMap.isPermutation();
}

// tensor::CollapseShapeOp is fusable with its producer (i.e. reshape as a
// consumer).
static bool
isTensorReshapeOpFoldableByLinearization(tensor::CollapseShapeOp collapseOp,
                                         AffineMap useIndexMap,
                                         bool asProducer) {
  if (asProducer)
    return false;
  return useIndexMap.isPermutation();
}

/// Check if the reshape operation is only expansion into/collapsing of
/// unit-dimensions.
template <typename TensorReshapeOp>
static bool isUnitDimExpansionOnly(TensorReshapeOp reshapeOp) {
  constexpr bool isExpanding =
      std::is_same<TensorReshapeOp, tensor::ExpandShapeOp>::value;
  ArrayRef<int64_t> expandedShape =
      (isExpanding ? reshapeOp.getResultType().getShape()
                   : reshapeOp.getSrcType().getShape());
  for (auto &indices : reshapeOp.getReassociationIndices()) {
    unsigned numUnitDims = 0;
    for (int64_t position : indices)
      if (expandedShape[position] == 1)
        numUnitDims++;
    if (numUnitDims != indices.size() - 1)
      return false;
  }
  return true;
}
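Concretely, this requires all but one dimension in each reassociation group to be unit extent: a collapse with reassociation [[0, 1]] from tensor<1x5xf32> to tensor<5xf32> passes (one of the two grouped dimensions is unit), while the same reassociation on tensor<4x5xf32> fails, since the group would fold two non-unit dimensions.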

namespace {
/// Pattern to fold a tensor_expand_shape (or tensor_collapse_shape) op with
/// its consumer by using the source of the reshape op as the operand in the
/// consumer (instead of the result of the reshape). The corresponding index
/// map in the consumer needs to be modified to linearize the folded dimension.
///
/// For example,
///
///   #map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
///   %0 = tensor.expand_shape %arg0 [[0], [1, 2], [3]]
///     : tensor<?x?x?xf32> into tensor<?x?x4x?xf32>
///   %1 = linalg.generic { indexing_maps = [#map0, #map0, #map0], ... }
///     ins(%0, %arg1 : tensor<?x?x4x?xf32>, tensor<?x?x4x?xf32>) ...
///     -> tensor<?x?x4x?xf32>
///
/// can be folded into
///
///   #map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1 * 4 + d2, d3)>
///   #map1 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
///   %0 = linalg.generic { indexing_maps = [#map0, #map1, #map1] ... }
///     ins(%arg0, %arg1 : tensor<?x?x?xf32>, tensor<?x?x4x?xf32>) ...
///     -> tensor<?x?x4x?xf32>
template <bool foldUnitDimReshapesOnly, typename TensorReshapeOp>
struct FoldProducerReshapeOpByLinearization
    : public OpRewritePattern<GenericOp> {
  using OpRewritePattern<GenericOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(GenericOp genericOp,
                                PatternRewriter &rewriter) const override {
    if (!genericOp.hasTensorSemantics())
      return failure();
    SmallVector<OpOperand *> inputOperands = genericOp.getInputOperands();
    for (const auto &en : llvm::enumerate(inputOperands)) {
      auto reshapeOp = en.value()->get().getDefiningOp<TensorReshapeOp>();
      if (!reshapeOp)
        continue;

      if (!isTensorReshapeOpFoldableByLinearization(
              reshapeOp, genericOp.getTiedIndexingMap(en.value()),
              /*asProducer=*/true) ||
          (foldUnitDimReshapesOnly && !isUnitDimExpansionOnly(reshapeOp)))
        continue;

      // Compute the fused operands list.
      SmallVector<Value> fusedOperands = genericOp.getInputOperands();
      fusedOperands[en.index()] = reshapeOp.src();
      SmallVector<Value> outputOperands = genericOp.getOutputOperands();
      llvm::append_range(fusedOperands, outputOperands);

      // Compute indexing_maps for the fused operation. The indexing_maps for
      // the operands of the consumer that aren't fused stay the same.
      SmallVector<AffineMap> fusedIndexMaps = genericOp.getIndexingMaps();

      // Compute the indexing map to use for the result of the producer.
      AffineMap modifiedMap =
          linearizeCollapsedDims(fusedIndexMaps[en.index()], reshapeOp);
      // The modified map cannot have symbols.
      if (modifiedMap.getNumSymbols())
        return failure();
      for (AffineExpr expr : modifiedMap.getResults()) {
        if (!expr.isPureAffine())
          return failure();
      }
      fusedIndexMaps[en.index()] = modifiedMap;

      // Further check that the resulting index maps can be fused and
      // inverted. Without this the resulting op is not legal.
      if (!inversePermutation(concatAffineMaps(fusedIndexMaps))) {
        return rewriter.notifyMatchFailure(
            genericOp, "fused op loop bound computation failed");
      }

      rewriter.startRootUpdate(genericOp);
      genericOp->setOperands(fusedOperands);
      genericOp.indexing_mapsAttr(
          rewriter.getAffineMapArrayAttr(fusedIndexMaps));
      rewriter.finalizeRootUpdate(genericOp);
      return success();
    }
    return failure();
  }
};

/// Pattern to fold a tensor_collapse_shape or tensor_expand_shape op with its
/// producer. The corresponding index map in the consumer needs to be modified
/// to linearize the folded dimension.
template <bool foldUnitDimReshapesOnly, typename TensorReshapeOp>
struct FoldConsumerReshapeOpByLinearization
    : public OpRewritePattern<TensorReshapeOp> {
  using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
                                PatternRewriter &rewriter) const override {
    GenericOp producer = reshapeOp.src().template getDefiningOp<GenericOp>();
    if (!producer || !producer.hasTensorSemantics() ||
        producer.getNumOutputs() != 1 ||
        !isTensorReshapeOpFoldableByLinearization(
            reshapeOp,
            producer.getTiedIndexingMap(producer.getOutputOperand(0)),
            /*asProducer=*/false) ||
        (foldUnitDimReshapesOnly && !isUnitDimExpansionOnly(reshapeOp)))
      return failure();
    // The indexing_maps for the operands of the fused operation are the same
    // as those for the operands of the producer.
    SmallVector<AffineMap> fusedIndexMaps = producer.getIndexingMaps();

    // Compute the indexing map to use for the operand of the producer.
    AffineMap modifiedMap = linearizeCollapsedDims(
        producer.getTiedIndexingMap(producer.getOutputOperand(0)), reshapeOp);
    for (AffineExpr expr : modifiedMap.getResults()) {
      if (!expr.isPureAffine()) {
        return rewriter.notifyMatchFailure(
            producer, "fused op indexing map is not affine");
      }
    }
    fusedIndexMaps.back() = modifiedMap;

    // Further check that the resulting index maps can be fused and
    // inverted. Without this the resulting op is not legal.
    if (!inversePermutation(concatAffineMaps(fusedIndexMaps))) {
      return rewriter.notifyMatchFailure(
          producer, "fused op loop bound computation failed");
    }

    Location loc = producer.getLoc();
    SmallVector<Value> inputOperands = producer.getInputOperands();
    Value output = rewriter.create<TensorReshapeOp>(
        loc, producer.getOutputOperand(0)->get(),
        reshapeOp.getReassociationExprs());
    auto fusedOp = rewriter.create<GenericOp>(
        loc, reshapeOp.getResultType(),
        /*inputs=*/inputOperands,
        // TODO: handle outputs.
        /*outputs=*/output, rewriter.getAffineMapArrayAttr(fusedIndexMaps),
        producer.iterator_types(),
        /*doc=*/nullptr,
        /*library_call=*/nullptr);
    auto &fusedRegion = fusedOp->getRegion(0);
    rewriter.cloneRegionBefore(producer->getRegion(0), fusedRegion,
                               fusedRegion.begin());
    rewriter.replaceOp(reshapeOp, fusedOp->getResults());
    return success();
  }
};
} // namespace

//===---------------------------------------------------------------------===//
// Methods and patterns that fuse reshape ops with elementwise operations by
// expanding the dimensionality of the elementwise operations.
@@ -1737,174 +1480,6 @@ private:
};
} // namespace

//===---------------------------------------------------------------------===//
// Methods and patterns to convert tensor.expand_shape -> linalg.generic
// into linalg.generic -> tensor.expand_shape, i.e. push the reshape down.
//===---------------------------------------------------------------------===//

// TODO(ravishankarm): This pattern is to be deprecated in favor of fusion by
// collapsing, which provides more general functionality. This pattern is very
// specific to one use case; fusion by collapsing can give clients the same
// control through its control function.

static SmallVector<ReassociationIndices>
getReassociationIndices(ArrayRef<AffineMap> maps) {
  SmallVector<ReassociationIndices> reassociation;
  for (AffineMap map : maps) {
    ReassociationIndices indices;
    for (unsigned i = 0, e = map.getNumResults(); i < e; i++) {
      unsigned pos = map.getResult(i).cast<AffineDimExpr>().getPosition();
      indices.push_back(pos);
    }
    reassociation.push_back(indices);
  }
  return reassociation;
}

namespace {
/// Pattern to move a rank-reducing reshape after an elementwise linalg generic
/// op. This is useful to expose more fusion opportunities between named ops
/// and generic ops. This can only be done if there is no broadcast or
/// permutation within the dimensions we need to merge.
///
/// For example,
///
///   %0 = tensor.expand_shape %A [[0, 1], [2]]
///     : tensor<12544x16xf32> into tensor<112x112x16xf32>
///   %2 = linalg.generic {indexing_maps = [
///     affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
///     affine_map<(d0, d1, d2) -> (d2)>,
///     affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types =
///     ["parallel", "parallel", "parallel"]} {
///   } -> tensor<112x112x16xf32>
///
/// into
///
///   %2 = linalg.generic {indexing_maps = [
///     affine_map<(d0, d1) -> (d0, d1)>,
///     affine_map<(d0, d1) -> (d1)>,
///     affine_map<(d0, d1) -> (d0, d1)>],
///     iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg1
///     : tensor<12544x16xf32>, tensor<16xf32>) outs(%1 : tensor<12544x16xf32>) {
///   } -> tensor<12544x16xf32>
///   %3 = tensor.expand_shape %2 [[0, 1], [2]]
///     : tensor<12544x16xf32> into tensor<112x112x16xf32>
struct PushExpandingReshape : public OpRewritePattern<GenericOp> {
  using OpRewritePattern<GenericOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(GenericOp genericOp,
                                PatternRewriter &rewriter) const override {
    // Only apply to elementwise linalg on tensors.
    if (!genericOp.hasTensorSemantics() || genericOp.hasIndexSemantics() ||
        genericOp.getNumParallelLoops() != genericOp.getNumLoops())
      return failure();
    // Only support identity output maps. This could be extended to
    // permutations if needed.
    if (llvm::any_of(genericOp.getOutputOperands(), [&](OpOperand *opOperand) {
          return !genericOp.getTiedIndexingMap(opOperand).isIdentity();
        }))
      return failure();
    int64_t destRank = genericOp.getNumParallelLoops();
    SmallVector<Value> newOperands = genericOp.getInputOperands();
    tensor::ExpandShapeOp reshapeFound;
    // 1. Look for tensor_expand_shape operands and record the dimensions
    //    merged.
    SmallVector<OpOperand *> inputOperands = genericOp.getInputOperands();
    for (const auto &en : llvm::enumerate(inputOperands)) {
      auto reshapeOp =
          en.value()->get().template getDefiningOp<tensor::ExpandShapeOp>();
      if (!reshapeOp)
        continue;
      // TODO: We could support a non-identity map as long as the merged
      // dimensions are still contiguous.
      if (!genericOp.getTiedIndexingMap(en.value()).isIdentity())
        continue;
      if (reshapeFound) {
        // Only support a second reshape op if it has the same reassociation
        // maps.
        if (reshapeFound.getReassociationMaps() ==
            reshapeOp.getReassociationMaps())
          newOperands[en.index()] = reshapeOp.src();
        continue;
      }
      reshapeFound = reshapeOp;
      newOperands[en.index()] = reshapeOp.src();
    }
    if (!reshapeFound)
      return failure();

    // Calculate the reassociation indices and the associated reverse map.
    SmallVector<ReassociationIndices> reassociation =
        getReassociationIndices(reshapeFound.getReassociationMaps());
    SmallVector<unsigned> remap(destRank);
    for (auto &indices : llvm::enumerate(reassociation)) {
      for (int64_t index : indices.value()) {
        remap[index] = indices.index();
      }
    }
    // 2. Verify that we can merge the dimensions in the linalg op and that we
    //    don't need to create new reshape operands. Inserting new reshape
    //    operands would defeat the purpose of the transformation.
    for (const auto &en : llvm::enumerate(inputOperands)) {
      if (en.value()->get() == newOperands[en.index()]) {
        AffineMap map = genericOp.getTiedIndexingMap(en.value());
        for (unsigned i : llvm::seq(unsigned(0), map.getNumResults())) {
          if (reassociation[remap[map.getDimPosition(i)]].size() > 1)
            return failure();
        }
      }
    }

    // 3. Calculate the affine map remapping and the reassociation to apply to
    //    the output tensors.
    SmallVector<AffineMap> newMaps;
    unsigned newRank = reassociation.size();
    for (auto map : genericOp.getIndexingMaps()) {
      SmallVector<AffineExpr> newExprs;
      for (auto expr : map.getResults()) {
        unsigned position = expr.template cast<AffineDimExpr>().getPosition();
        // Skip merged dimensions except for the last of each group.
        if (reassociation[remap[position]].back() == position) {
          newExprs.push_back(
              getAffineDimExpr(remap[position], genericOp.getContext()));
        }
      }
      newMaps.push_back(
          AffineMap::get(newRank, 0, newExprs, genericOp.getContext()));
    }

    // 4. Reshape the output tensors.
    SmallVector<Value> newOutputs;
    SmallVector<Type> newOutputTypes;
    for (auto output : genericOp.outputs()) {
      auto newOutputType = RankedTensorType::get(
          reshapeFound.getSrcType().getShape(),
          output.getType().template cast<RankedTensorType>().getElementType());
      Value newOutput = rewriter.create<tensor::CollapseShapeOp>(
          genericOp->getLoc(), newOutputType, output, reassociation);
      newOutputTypes.push_back(newOutputType);
      newOutputs.push_back(newOutput);
    }
    // 5. Create a new generic op with lower rank.
    SmallVector<StringRef> iteratorTypes(newRank,
                                         getParallelIteratorTypeName());
    auto newOp = rewriter.create<GenericOp>(genericOp->getLoc(), newOutputTypes,
                                            newOperands, newOutputs, newMaps,
                                            iteratorTypes);
    rewriter.inlineRegionBefore(genericOp.region(), newOp.region(),
                                newOp.region().begin());
    // 6. Reshape the results so that the types match the uses.
    SmallVector<Value> newResults;
    for (const auto &result : llvm::enumerate(newOp->getResults())) {
      newResults.push_back(rewriter.create<tensor::ExpandShapeOp>(
          genericOp->getLoc(), genericOp.getOutputTensorTypes()[result.index()],
          result.value(), reassociation));
    }
    rewriter.replaceOp(genericOp, newResults);
    return success();
  }
};
} // namespace

//===---------------------------------------------------------------------===//
// Methods and patterns that fuse constants with linalg.generic operations.
//===---------------------------------------------------------------------===//
@@ -2093,27 +1668,6 @@ struct FoldFillWithGenericOp : public OpRewritePattern<GenericOp> {
  }
};
} // namespace
//===---------------------------------------------------------------------===//
// Methods that add patterns described in this file to a pattern list.
//===---------------------------------------------------------------------===//

void mlir::linalg::populateFoldReshapeOpsByLinearizationPatterns(
    RewritePatternSet &patterns) {
  patterns.add<
      FoldProducerReshapeOpByLinearization<false, tensor::CollapseShapeOp>,
      FoldProducerReshapeOpByLinearization<false, tensor::ExpandShapeOp>,
      FoldConsumerReshapeOpByLinearization<false, tensor::CollapseShapeOp>>(
      patterns.getContext());
}

void mlir::linalg::populateFoldUnitDimsReshapeOpsByLinearizationPatterns(
    RewritePatternSet &patterns) {
  patterns
      .add<FoldProducerReshapeOpByLinearization<true, tensor::CollapseShapeOp>,
           FoldProducerReshapeOpByLinearization<true, tensor::ExpandShapeOp>,
           FoldConsumerReshapeOpByLinearization<true, tensor::CollapseShapeOp>>(
          patterns.getContext());
}

void mlir::linalg::populateFoldReshapeOpsByExpansionPatterns(
    RewritePatternSet &patterns,
@@ -2140,28 +1694,10 @@ void mlir::linalg::populateElementwiseOpsFusionPatterns(
      RemoveOutsDependency>(context);
}

void mlir::linalg::populatePushReshapeOpsPatterns(RewritePatternSet &patterns) {
  auto *context = patterns.getContext();
  patterns.add<PushExpandingReshape>(context);
}

//===---------------------------------------------------------------------===//
// Passes
//===---------------------------------------------------------------------===//

bool mlir::linalg::skipUnitDimReshape(const OpResult &producer,
                                      OpOperand &consumer) {
  if (auto producerCollapseOp =
          dyn_cast<tensor::CollapseShapeOp>(producer.getOwner())) {
    return !isUnitDimExpansionOnly(producerCollapseOp);
  }
  if (auto consumerExpandOp =
          dyn_cast<tensor::ExpandShapeOp>(consumer.getOwner())) {
    return !isUnitDimExpansionOnly(consumerExpandOp);
  }
  return true;
}
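Note the inverted sense here: this default control function, removed along with the option it served, returned false exactly for reshapes that only introduce or collapse unit dimensions, so the expansion-based fusion patterns skipped those unless the pass was run with allow-folding-unit-dim-reshapes=true. The pass hunk below now passes defaultControlFn unconditionally.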

namespace {

/// Pass that fuses generic ops on tensors. Used only for testing.
@@ -2186,9 +1722,7 @@ struct LinalgElementwiseOpFusionPass
    // Add elementwise op fusion patterns.
    populateElementwiseOpsFusionPatterns(patterns, defaultControlFn);

    populateFoldReshapeOpsByExpansionPatterns(
        patterns,
        allowFoldingUnitDimReshapes ? defaultControlFn : skipUnitDimReshape);
    populateFoldReshapeOpsByExpansionPatterns(patterns, defaultControlFn);

    // Add the sparse tensor rewriting patterns.
    populateSparseTensorRewriting(patterns);

@@ -2212,27 +1746,8 @@ struct LinalgElementwiseOpFusionPass
  }
};

/// Pass to test folding of reshape ops with generic ops by linearization.
struct FoldReshapeOpsByLinearizationPass
    : public LinalgFoldReshapeOpsByLinearizationBase<
          FoldReshapeOpsByLinearizationPass> {
  void runOnOperation() override {
    Operation *op = getOperation();
    RewritePatternSet patterns(op->getContext());
    populateFoldReshapeOpsByLinearizationPatterns(patterns);
    if (allowFoldingUnitDimReshapes) {
      populateFoldUnitDimsReshapeOpsByLinearizationPatterns(patterns);
    }
    (void)applyPatternsAndFoldGreedily(op->getRegions(), std::move(patterns));
  }
};

} // namespace

std::unique_ptr<Pass> mlir::createLinalgElementwiseOpFusionPass() {
  return std::make_unique<LinalgElementwiseOpFusionPass>();
}

std::unique_ptr<Pass> mlir::createFoldReshapeOpsByLinearizationPass() {
  return std::make_unique<FoldReshapeOpsByLinearizationPass>();
}
@@ -1,4 +1,4 @@
// RUN: mlir-opt %s -test-linalg-elementwise-fusion-patterns=push-expanding-reshape -split-input-file | FileCheck %s
// RUN: mlir-opt %s -test-linalg-elementwise-fusion-patterns=fuse-with-reshape-by-collapsing -split-input-file | FileCheck %s

// CHECK-DAG: #[[$MAP2:.*]] = affine_map<(d0, d1) -> (d0, d1)>
// CHECK-DAG: #[[$MAP3:.*]] = affine_map<(d0, d1) -> (d1)>
@@ -124,30 +124,3 @@ func.func @type_correctness(%arg0 : tensor<6x5xi32>, %arg1 : tensor<5xf32>,
// CHECK-SAME: outs(%{{.+}} : tensor<6x5xf32>)
// CHECK: tensor.expand_shape %[[OP]]
// CHECK-SAME: tensor<6x5xf32> into tensor<2x3x5xf32>

// -----

func.func @generic_op_index_semantics(%A: tensor<?x16xi64>, %B: tensor<16xi64>, %init: tensor<?x112x16xi64>) -> tensor<?x112x16xi64> {
  %0 = tensor.expand_shape %A [[0, 1], [2]]
      : tensor<?x16xi64> into tensor<?x112x16xi64>
  %2 = linalg.generic {indexing_maps = [
      affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d2)>,
      affine_map<(d0, d1, d2) -> (d0, d1, d2)>],
      iterator_types = ["parallel", "parallel", "parallel"]}
      ins(%0, %B : tensor<?x112x16xi64>, tensor<16xi64>)
      outs(%init : tensor<?x112x16xi64>) {
    ^bb0(%arg1: i64, %arg2: i64, %arg3: i64):  // no predecessors
      %index = linalg.index 0 : index
      %1 = arith.index_cast %index : index to i64
      %add = arith.addi %arg1, %1 : i64
      %s = arith.subi %add, %arg2 : i64
      linalg.yield %s : i64
  } -> tensor<?x112x16xi64>
  return %2 : tensor<?x112x16xi64>
}
// CHECK: func @generic_op_index_semantics
// CHECK-SAME: %[[ARG0:.+]]: tensor<?x16xi64>
// CHECK: %[[RESHAPE:.+]] = tensor.expand_shape %[[ARG0]]
// CHECK: %[[RESULT:.+]] = linalg.generic
// CHECK-SAME: ins(%[[RESHAPE]]
// CHECK: return %[[RESULT]]
@@ -1,5 +1,5 @@
// RUN: mlir-opt %s -linalg-fuse-elementwise-ops="allow-folding-unit-dim-reshapes=false" -split-input-file | FileCheck %s
// RUN: mlir-opt %s -linalg-fuse-elementwise-ops="allow-folding-unit-dim-reshapes=true" -split-input-file | FileCheck %s --check-prefix=FOLDUNITDIM
// RUN: mlir-opt %s -test-linalg-elementwise-fusion-patterns=fuse-with-reshape-by-expansion -split-input-file | FileCheck %s

#map0 = affine_map<(d0, d1, d2) -> (d2, d0, d1)>
#map1 = affine_map<(d0, d1, d2) -> (d1, d2, d0)>
#map2 = affine_map<(d0, d1, d2) -> ()>
@@ -14,7 +14,7 @@ func.func @generic_op_reshape_producer_fusion(%arg0 : tensor<?x?x4x?xf32>,
      indexing_maps = [#map0, #map1, #map2, #map1],
      iterator_types = ["parallel", "parallel", "parallel"]}
      ins(%0, %arg1, %arg2 : tensor<?x?x?xf32>, tensor<?x?x?xf32>, f32)
      outs(%0 : tensor<?x?x?xf32>) {
      outs(%arg1 : tensor<?x?x?xf32>) {
    ^bb0(%arg3: f32, %arg4: f32, %arg5: f32, %s: f32):
      %1 = arith.mulf %arg3, %arg4 : f32
      %2 = arith.addf %1, %arg5 : f32
@@ -30,15 +30,15 @@ func.func @generic_op_reshape_producer_fusion(%arg0 : tensor<?x?x4x?xf32>,
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?x4x?xf32>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?x?xf32>
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: f32
// CHECK: %[[T0:.+]] = tensor.collapse_shape %[[ARG0]]
// CHECK-SAME: [0], [1, 2], [3]
// CHECK: %[[T1:.+]] = tensor.expand_shape %[[ARG1]]
// CHECK-SAME: [0], [1], [2, 3]
// CHECK: %[[T2:.+]] = tensor.expand_shape %[[ARG1]]
// CHECK-SAME: [0], [1], [2, 3]
// CHECK: %[[T3:.+]] = linalg.generic
// CHECK-SAME: indexing_maps = [#[[MAP5]], #[[MAP6]], #[[MAP7]], #[[MAP6]]]
// CHECK-SAME: ["parallel", "parallel", "parallel", "parallel"]
// CHECK-SAME: ins(%[[ARG0]], %[[T1]], %[[ARG2]] : tensor<?x?x4x?xf32>, tensor<?x?x?x4xf32>, f32)
// CHECK-SAME: outs(%{{.+}} : tensor<?x?x?x4xf32>)
// CHECK-SAME: outs(%[[T2]] : tensor<?x?x?x4xf32>)
// CHECK: %[[T4:.+]] = tensor.collapse_shape %[[T3]]
// CHECK-SAME: [0], [1], [2, 3]
// CHECK-SAME: tensor<?x?x?x4xf32> into tensor<?x?x?xf32>
@@ -80,12 +80,14 @@ func.func @generic_op_reshape_consumer_fusion(%arg0 : tensor<?x?xf32>,
// CHECK-SAME: tensor<?x?xf32> into tensor<?x4x?x5xf32>
// CHECK: %[[T1:.+]] = tensor.expand_shape %[[ARG1]]
// CHECK-SAME: [0], [1, 2, 3]
// CHECK: %[[T2:.+]] = tensor.expand_shape %[[ARG0]]
// CHECK-SAME: [0], [1, 2, 3]
// CHECK-SAME: tensor<?x?xf32> into tensor<?x4x?x5xf32>
// CHECK: %[[T3:.+]] = linalg.generic
// CHECK-SAME: indexing_maps = [#[[MAP2]], #[[MAP2]], #[[MAP3]], #[[MAP2]]]
// CHECK-SAME: ["parallel", "parallel", "parallel", "parallel"]
// CHECK-SAME: ins(%[[T0]], %[[T1]], %[[ARG2]] : tensor<?x4x?x5xf32>, tensor<?x4x?x5xf32>, f32)
// CHECK-SAME: outs(%{{.+}} : tensor<?x4x?x5xf32>)
// CHECK-SAME: outs(%[[T2]] : tensor<?x4x?x5xf32>)
// CHECK: return %[[T3]] : tensor<?x4x?x5xf32>
@@ -121,11 +123,14 @@ func.func @reshape_as_consumer_permutation
// CHECK: %[[T1:.+]] = tensor.expand_shape %[[ARG1]]
// CHECK-SAME: [0, 1, 2], [3]
// CHECK-SAME: tensor<?x?xf32> into tensor<3x4x?x?xf32>
// CHECK: %[[T2:.+]] = tensor.expand_shape %[[ARG0]]
// CHECK-SAME: [0, 1], [2], [3, 4, 5]
// CHECK-SAME: tensor<?x?x?xf32> into tensor<?x2x?x3x4x?xf32>
// CHECK: %[[T3:.+]] = linalg.generic
// CHECK-SAME: indexing_maps = [#[[MAP8]], #[[MAP9]], #[[MAP10]]]
// CHECK-SAME: ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel"]
// CHECK-SAME: ins(%[[T0]], %[[T1]] : tensor<3x4x?x?x2x?xf32>, tensor<3x4x?x?xf32>)
// CHECK-SAME: outs(%{{.+}} : tensor<?x2x?x3x4x?xf32>)
// CHECK-SAME: outs(%[[T2]] : tensor<?x2x?x3x4x?xf32>)
// CHECK: return %[[T3]] : tensor<?x2x?x3x4x?xf32>

// -----
@@ -155,14 +160,19 @@ func.func @generic_op_reshape_consumer_static(%arg0: tensor<264x4xf32>)
// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
// CHECK: func @generic_op_reshape_consumer_static
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<264x4xf32>
// CHECK-DAG: %[[CST:.+]] = arith.constant
// CHECK-SAME: : tensor<8x33x4xf32>
// CHECK-DAG: %[[INIT:.+]] = linalg.init_tensor [264, 4]
// CHECK: %[[T0:.+]] = tensor.expand_shape %[[ARG0]]
// CHECK-SAME: [0, 1], [2]
// CHECK-SAME: tensor<264x4xf32> into tensor<8x33x4xf32>
// CHECK: %[[T1:.+]] = linalg.init_tensor [8, 33, 4]
// CHECK: %[[T1:.+]] = tensor.expand_shape %[[INIT]]
// CHECK-SAME: [0, 1], [2]
// CHECK-SAME: : tensor<264x4xf32> into tensor<8x33x4xf32>
// CHECK: %[[T2:.+]] = linalg.generic
// CHECK-SAME: indexing_maps = [#[[MAP2]], #[[MAP2]]]
// CHECK-SAME: indexing_maps = [#[[MAP2]], #[[MAP2]], #[[MAP2]]]
// CHECK-SAME: ["parallel", "parallel", "parallel"]
// CHECK-SAME: ins(%[[T0]] : tensor<8x33x4xf32>)
// CHECK-SAME: ins(%[[T0]], %[[CST]] :
// CHECK-SAME: outs(%[[T1]] : tensor<8x33x4xf32>)
// CHECK: return %[[T2]] : tensor<8x33x4xf32>
@@ -246,7 +256,8 @@ func.func @indexed_producer_reshape_consumer_fusion(%arg0 : tensor<?x?xi32>,
}

// Only check the body in the indexed version of the test.
// CHECK: #[[MAP:.+]] = affine_map<(d0, d1, d2) -> (d0 + d1 * 5 + d2 * 20)>
// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1) -> (d0 + d1 * 4)>
// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1) -> (d0 + d1 * 5)>
// CHECK: func @indexed_producer_reshape_consumer_fusion
// CHECK: linalg.generic
// CHECK: ^{{.*}}(
@@ -256,11 +267,12 @@ func.func @indexed_producer_reshape_consumer_fusion(%arg0 : tensor<?x?xi32>,
// CHECK-DAG: %[[IDX1:.+]] = linalg.index 1 : index
// CHECK-DAG: %[[IDX2:.+]] = linalg.index 2 : index
// CHECK-DAG: %[[IDX3:.+]] = linalg.index 3 : index
// CHECK-DAG: %[[T3:.+]] = affine.apply #[[MAP]](%[[IDX3]], %[[IDX2]], %[[IDX1]])
// CHECK: %[[T1:.+]] = affine.apply #[[MAP1]](%[[IDX2]], %[[IDX1]])
// CHECK: %[[T2:.+]] = affine.apply #[[MAP2]](%[[IDX3]], %[[T1]])
// CHECK: %[[T4:.+]] = arith.muli %[[ARG3]], %[[ARG4]]
// CHECK: %[[T5:.+]] = arith.index_cast %[[IDX0]]
// CHECK: %[[T6:.+]] = arith.addi %[[T4]], %[[T5]]
// CHECK: %[[T7:.+]] = arith.index_cast %[[T3]]
// CHECK: %[[T7:.+]] = arith.index_cast %[[T2]]
// CHECK: %[[T8:.+]] = arith.addi %[[T6]], %[[T7]]
// CHECK: linalg.yield %[[T8]]
@@ -295,24 +307,29 @@ func.func @reshape_as_consumer_permutation
  return %d : tensor<2x3x4x5x6x7xi32>
}

// -----

// CHECK-DAG: #[[MAP5:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d3, d4, d0, d1, d5)>
// CHECK-DAG: #[[MAP6:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d3, d4, d5)>
// CHECK-DAG: #[[MAP7:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d5, d2, d3, d4)>
// CHECK-DAG: #[[MAP8:.+]] = affine_map<(d0, d1) -> (d0 + d1 * 3)>
// CHECK-DAG: #[[MAP9:.+]] = affine_map<(d0, d1, d2) -> (d0 + d1 * 7 + d2 * 42)>
// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d3, d4, d0, d1, d5)>
// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d3, d4, d5)>
// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d5, d2, d3, d4)>
// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1) -> (d0 + d1 * 3)>
// CHECK-DAG: #[[MAP4:.+]] = affine_map<(d0, d1) -> (d0 + d1 * 6)>
// CHECK-DAG: #[[MAP5:.+]] = affine_map<(d0, d1) -> (d0 + d1 * 7)>
// CHECK: func @reshape_as_consumer_permutation
// CHECK-SAME: %[[ARG0:.+]]: tensor<210x6x4xi32>
// CHECK-SAME: %[[ARG1:.+]]: tensor<210x4xi32>
// CHECK-DAG: %[[INIT:.+]] = linalg.init_tensor [6, 4, 210]
// CHECK-DAG: %[[T1:.+]] = tensor.expand_shape %[[ARG0]]
// CHECK-SAME: [0, 1, 2], [3, 4], [5]
// CHECK-DAG: %[[T2:.+]] = tensor.expand_shape %[[ARG1]]
// CHECK-SAME: [0, 1, 2], [3]
// CHECK-DAG: %[[T0:.+]] = linalg.init_tensor [2, 3, 4, 5, 6, 7]
// CHECK-DAG: %[[T3:.+]] = tensor.expand_shape %[[INIT]]
// CHECK-SAME: [0, 1], [2], [3, 4, 5]
// CHECK-SAME: : tensor<6x4x210xi32> into tensor<2x3x4x5x6x7xi32>
// CHECK: %[[T4:.+]] = linalg.generic
// CHECK-SAME: indexing_maps = [#[[MAP5]], #[[MAP6]], #[[MAP7]]]
// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]]]
// CHECK-SAME: ins(%[[T1]], %[[T2]] : tensor<5x6x7x2x3x4xi32>, tensor<5x6x7x4xi32>)
// CHECK-SAME: outs(%[[T0]] : tensor<2x3x4x5x6x7xi32>)
// CHECK-SAME: outs(%[[T3]] : tensor<2x3x4x5x6x7xi32>)
// CHECK: ^{{.+}}(
// CHECK-SAME: %[[ARG8:[a-zA-Z0-9]+]]: i32, %[[ARG9:[a-zA-Z0-9]+]]: i32,
// CHECK-SAME: %[[ARG10:[a-zA-Z0-9]+]]: i32)
@@ -322,15 +339,16 @@ func.func @reshape_as_consumer_permutation
// CHECK-DAG: %[[IDX3:.+]] = linalg.index 3 : index
// CHECK-DAG: %[[IDX4:.+]] = linalg.index 4 : index
// CHECK-DAG: %[[IDX5:.+]] = linalg.index 5 : index
// CHECK-DAG: %[[T5:.+]] = affine.apply #[[MAP8]](%[[IDX1]], %[[IDX0]])
// CHECK-DAG: %[[T6:.+]] = affine.apply #[[MAP9]](%[[IDX4]], %[[IDX3]], %[[IDX2]])
// CHECK-DAG: %[[T7:.+]] = arith.addi %[[ARG8]], %[[ARG9]]
// CHECK: %[[T8:.+]] = arith.index_cast %[[T5]]
// CHECK: %[[T9:.+]] = arith.addi %[[T7]], %[[T8]]
// CHECK: %[[T10:.+]] = arith.index_cast %[[T6]]
// CHECK: %[[T11:.+]] = arith.addi %[[T9]], %[[T10]]
// CHECK: %[[T12:.+]] = arith.index_cast %[[IDX5]]
// CHECK: %[[T13:.+]] = arith.addi %[[T11]], %[[T12]]
// CHECK-DAG: %[[T5:.+]] = affine.apply #[[MAP3]](%[[IDX1]], %[[IDX0]])
// CHECK-DAG: %[[T6:.+]] = affine.apply #[[MAP4]](%[[IDX3]], %[[IDX2]])
// CHECK-DAG: %[[T7:.+]] = affine.apply #[[MAP5]](%[[IDX4]], %[[T6]])
// CHECK-DAG: %[[T8:.+]] = arith.addi %[[ARG8]], %[[ARG9]]
// CHECK: %[[T9:.+]] = arith.index_cast %[[T5]]
// CHECK: %[[T10:.+]] = arith.addi %[[T8]], %[[T9]]
// CHECK: %[[T11:.+]] = arith.index_cast %[[T7]]
// CHECK: %[[T12:.+]] = arith.addi %[[T10]], %[[T11]]
// CHECK: %[[T13:.+]] = arith.index_cast %[[IDX5]]
// CHECK: %[[T14:.+]] = arith.addi %[[T12]], %[[T13]]

// -----
@@ -421,94 +439,18 @@ func.func @generic_op_reshape_consumer_fusion_projected(%arg0 : tensor<?x?xf32>,
// CHECK: %[[T1:.+]] = tensor.expand_shape %[[ARG1]]
// CHECK-SAME: [0, 1, 2], [3]
// CHECK-SAME: tensor<?x?xf32> into tensor<?x4x5x?xf32>
// CHECK: %[[T2:.+]] = tensor.expand_shape %[[ARG0]]
// CHECK-SAME: [0], [1, 2, 3]
// CHECK-SAME: tensor<?x?xf32> into tensor<?x?x4x5xf32>
// CHECK: %[[T3:.+]] = linalg.generic
// CHECK-SAME: indexing_maps = [#[[MAP4]], #[[MAP4]], #[[MAP5]]]
// CHECK-SAME: ["parallel", "parallel", "parallel", "parallel"]
// CHECK-SAME: ins(%[[T0]], %[[T1]] : tensor<?x4x5x?xf32>, tensor<?x4x5x?xf32>)
// CHECK-SAME: outs(%{{.+}} : tensor<?x?x4x5xf32>)
// CHECK-SAME: outs(%[[T2]] : tensor<?x?x4x5xf32>)
// CHECK: return %[[T3]] : tensor<?x?x4x5xf32>

// -----

func.func @unit_dim_reshape_expansion(%arg0 : tensor<1x5xf32>) -> tensor<5x5xf32> {
  %0 = tensor.collapse_shape %arg0 [[0, 1]]
      : tensor<1x5xf32> into tensor<5xf32>
  %1 = linalg.init_tensor [5, 5] : tensor<5x5xf32>
  %2 = linalg.generic
      {indexing_maps = [affine_map<(d0, d1) -> (d0)>,
                        affine_map<(d0, d1) -> (d0, d1)>],
       iterator_types = ["parallel", "parallel"]}
      ins(%0 : tensor<5xf32>) outs(%1 : tensor<5x5xf32>) {
    ^bb0(%arg2: f32, %arg3: f32):
      linalg.yield %arg2 : f32
  } -> tensor<5x5xf32>
  return %2 : tensor<5x5xf32>
}
// CHECK: func @unit_dim_reshape_expansion
// CHECK-DAG: tensor.collapse_shape
// CHECK-DAG: linalg.init_tensor
// CHECK: linalg.generic

// -----

func.func @unit_dim_reshape_collapse(%arg0 : tensor<5xf32>) -> tensor<5x1x5xf32> {
  %0 = linalg.init_tensor [5, 5] : tensor<5x5xf32>
  %1 = linalg.generic
      {indexing_maps = [affine_map<(d0, d1) -> (d0)>,
                        affine_map<(d0, d1) -> (d0, d1)>],
       iterator_types = ["parallel", "parallel"]}
      ins(%arg0 : tensor<5xf32>) outs(%0 : tensor<5x5xf32>) {
    ^bb0(%arg2: f32, %arg3: f32):
      linalg.yield %arg2 : f32
  } -> tensor<5x5xf32>
  %2 = tensor.expand_shape %1 [[0, 1], [2]]
      : tensor<5x5xf32> into tensor<5x1x5xf32>
  return %2 : tensor<5x1x5xf32>
}
// CHECK: func @unit_dim_reshape_collapse
// CHECK: linalg.init_tensor
// CHECK: linalg.generic
// CHECK: tensor.expand_shape

// -----

func.func @unit_dim_reshape_expansion_full
    (%arg0 : tensor<1x?x1x2x1x4xf32>, %arg1 : tensor<?x2x4xf32>)
    -> tensor<?x2x4xf32> {
  %c1 = arith.constant 1 : index
  %0 = tensor.collapse_shape %arg0 [[0, 1, 2], [3, 4], [5]]
      : tensor<1x?x1x2x1x4xf32> into tensor<?x2x4xf32>
  %1 = tensor.dim %arg0, %c1 : tensor<1x?x1x2x1x4xf32>
  %2 = linalg.init_tensor [%1, 2, 4] : tensor<?x2x4xf32>
  %3 = linalg.generic
      {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
                        affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
                        affine_map<(d0, d1, d2) -> (d0, d1, d2)>],
       iterator_types = ["parallel", "parallel", "parallel"]}
      ins(%0, %arg1 : tensor<?x2x4xf32>, tensor<?x2x4xf32>)
      outs(%2 : tensor<?x2x4xf32>) {
    ^bb0(%arg2: f32, %arg3: f32, %arg4: f32):
      %4 = arith.mulf %arg2, %arg3 : f32
      linalg.yield %4 : f32
  } -> tensor<?x2x4xf32>
  return %3 : tensor<?x2x4xf32>
}
// CHECK: func @unit_dim_reshape_expansion_full
// CHECK-DAG: tensor.collapse_shape
// CHECK-DAG: linalg.init_tensor
// CHECK: linalg.generic
// CHECK-SAME: ins(%{{.+}}, %{{.+}} : tensor<?x2x4xf32>, tensor<?x2x4xf32>)

// FOLDUNITDIM: func @unit_dim_reshape_expansion_full
// FOLDUNITDIM-SAME: %[[ARG0:.+]]: tensor<1x?x1x2x1x4xf32>
// FOLDUNITDIM-SAME: %[[ARG1:.+]]: tensor<?x2x4xf32>
// FOLDUNITDIM-DAG: %[[RESHAPE:.+]] = tensor.expand_shape %[[ARG1]]
// FOLDUNITDIM: linalg.generic
// FOLDUNITDIM-SAME: ins(%[[ARG0]], %[[RESHAPE]] : tensor<1x?x1x2x1x4xf32>, tensor<1x?x1x2x1x4xf32>)
// FOLDUNITDIM-SAME: outs(%{{.+}} : tensor<1x?x1x2x1x4xf32>)

// -----

func.func @no_fuse_dynamic_dims(%arg0: tensor<?x?xf32>) -> tensor<?xf32> {
  %c0 = arith.constant 0 : index
  %0 = tensor.collapse_shape %arg0 [[0, 1]] : tensor<?x?xf32> into tensor<?xf32>
@@ -554,7 +496,6 @@ func.func @no_fuse_mismatched_dynamism(%arg0: tensor<2x1xi64>, %arg1: tensor<?xi
// CHECK-SAME: %[[ARG0:.+]]: tensor<2x1xi64>
// CHECK-SAME: %[[ARG1:.+]]: tensor<?xi64>
// CHECK: %[[RESHAPE:.+]] = tensor.collapse_shape %[[ARG0]]
// CHECK: %[[CAST:.+]] = tensor.cast %[[ARG1]] : tensor<?xi64> to tensor<2xi64>
// CHECK: %[[GENERIC:.+]] = linalg.generic
// CHECK-SAME: ins(%[[RESHAPE]], %[[CAST]] : tensor<2xi64>, tensor<2xi64>)
// CHECK-SAME: ins(%[[RESHAPE]], %[[ARG1]] : tensor<2xi64>, tensor<?xi64>)
// CHECK: return %[[GENERIC]]
@@ -1,287 +0,0 @@
// RUN: mlir-opt -split-input-file -linalg-fold-reshape-ops-by-linearization %s | FileCheck %s

// Note: These tests fuse the reshape ops by linearization. This can create
// indexing maps which are hard to analyse later on. These patterns are useful
// only if the folded dimensions in the reshape op are unit extent. Tests here
// are more general for testing purposes, but use of these patterns for
// non-unit dimensions should be deprecated.
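The maps these deleted tests expect, such as `(d0, d1 * 4 + d2, d3)` and `(d0, d1 * 20 + d2 * 5 + d3)` below, show the hard-to-analyse linearized forms the note warns about: once a result carries a `d1 * 4 + d2` term, the indexing map is no longer a projected permutation, which is what blocks later transformations on the fused op.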
|
||||
|
||||
#map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
|
||||
func.func @generic_op_reshape_producer_fusion(%arg0 : tensor<?x?x?xi32>)
|
||||
-> tensor<?x?x4x?xi32> {
|
||||
%0 = tensor.expand_shape %arg0 [[0], [1, 2], [3]] :
|
||||
tensor<?x?x?xi32> into tensor<?x?x4x?xi32>
|
||||
%1 = linalg.generic {
|
||||
indexing_maps = [#map0, #map0],
|
||||
iterator_types = ["parallel", "parallel", "parallel", "parallel"] }
|
||||
ins(%0 : tensor<?x?x4x?xi32>)
|
||||
outs(%0 : tensor<?x?x4x?xi32>) {
|
||||
^bb0(%arg6: i32, %arg7 : i32):
|
||||
%idx = linalg.index 0 : index
|
||||
%2 = arith.index_cast %idx : index to i32
|
||||
%3 = arith.addi %arg6, %2 : i32
|
||||
linalg.yield %3 : i32
|
||||
} -> tensor<?x?x4x?xi32>
|
||||
return %1 : tensor<?x?x4x?xi32>
|
||||
}
|
||||
// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1 * 4 + d2, d3)>
|
||||
// CHECK-DAG: #[[MAP4:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
|
||||
// CHECK: func @generic_op_reshape_producer_fusion
|
||||
// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?x?xi32>
|
||||
// CHECK: %[[T0:.+]] = tensor.expand_shape %[[ARG0]]
|
||||
// CHECK-SAME: [0], [1, 2], [3]
|
||||
// CHECK: linalg.generic
|
||||
// CHECK-SAME: indexing_maps = [#[[MAP3]], #[[MAP4]]]
|
||||
// CHECK-SAME: ins(%[[ARG0]] : tensor<?x?x?xi32>)
|
||||
// CHECK-SAME: outs(%[[T0]] : tensor<?x?x4x?xi32>)
|
||||
// CHECK: %[[IDX:.+]] = linalg.index 0 : index
|
||||
// CHECK-NEXT: %[[IDX_CASTED:.+]] = arith.index_cast %[[IDX]] : index to i32
|
||||
|
||||
// -----
|
||||
|
||||
#map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
|
||||
func.func @generic_op_reshape_consumer_fusion(%arg0 : tensor<?x?x4x5xi32>)
|
||||
-> tensor<?x?xi32> {
|
||||
%0 = linalg.generic {
|
||||
indexing_maps = [#map0, #map0],
|
||||
iterator_types = ["parallel", "parallel", "parallel", "parallel"] }
|
||||
ins(%arg0 : tensor<?x?x4x5xi32>) outs(%arg0 : tensor<?x?x4x5xi32>) {
|
||||
^bb0(%arg6: i32, %arg7: i32):
|
||||
%idx = linalg.index 0 : index
|
||||
%2 = arith.index_cast %idx : index to i32
|
||||
%3 = arith.addi %arg6, %2 : i32
|
||||
linalg.yield %3 : i32
|
||||
} -> tensor<?x?x4x5xi32>
|
||||
%1 = tensor.collapse_shape %0 [[0], [1, 2, 3]] :
|
||||
tensor<?x?x4x5xi32> into tensor<?x?xi32>
|
||||
return %1 : tensor<?x?xi32>
|
||||
}
|
||||
// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
|
||||
// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1 * 20 + d2 * 5 + d3)>
|
||||
// CHECK: func @generic_op_reshape_consumer_fusion
|
||||
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?x4x5xi32>
|
||||
// CHECK: %[[T0:.+]] = tensor.collapse_shape %[[ARG0]]
|
||||
// CHECK-SAME: [0], [1, 2, 3]
|
||||
// CHECK: linalg.generic
|
||||
// CHECK-SAME: indexing_maps = [#[[MAP2]], #[[MAP3]]]
|
||||
// CHECK-SAME: outs(%[[T0]] : tensor<?x?xi32>)
|
||||
// CHECK: %[[IDX:.+]] = linalg.index 0 : index
|
||||
// CHECK-NEXT: %[[IDX_CASTED:.+]] = arith.index_cast %[[IDX]] : index to i32
|
||||
// CHECK-NOT: tensor.collapse_shape
|
||||
|
||||
// -----
|
||||
|
||||
#map2 = affine_map<(d0, d1, d2) -> (d0, d2, d1)>
|
||||
#map3 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
|
||||
func.func @generic_op_021_permultation_reshape_producer_fusion(%arg0 : tensor<3x35xf32>) -> tensor<3x7x5xf32> {
|
||||
%0 = tensor.expand_shape %arg0 [[0], [1, 2]]
|
||||
: tensor<3x35xf32> into tensor<3x5x7xf32>
|
||||
%1 = linalg.init_tensor [3, 7, 5] : tensor<3x7x5xf32>
|
||||
%2 = linalg.generic
|
||||
{indexing_maps = [#map2, #map3],
|
||||
iterator_types = ["parallel", "parallel", "parallel"]}
|
||||
ins(%0 : tensor<3x5x7xf32>) outs(%1 : tensor<3x7x5xf32>) {
|
||||
^bb0(%arg2: f32, %arg3 : f32):
|
||||
linalg.yield %arg2 : f32
|
||||
} -> tensor<3x7x5xf32>
|
||||
return %2 : tensor<3x7x5xf32>
|
||||
}
|
||||
|
||||
// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2) -> (d0, d1 + d2 * 7)>
|
||||
// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
|
||||
// CHECK: func @generic_op_021_permultation_reshape_producer_fusion
|
||||
// CHECK-NOT: tensor.expand_shape
|
||||
// CHECK: linalg.generic
|
||||
// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]]]
|
||||
|
||||
// -----
|
||||
|
||||
#map2 = affine_map<(d0, d1, d2) -> (d1, d0, d2)>
|
||||
#map3 = affine_map<(d0, d1, d2) -> (d0, d2, d1)>
|
||||
func.func @generic_op_120_permutation_reshape_producer_fusion(%arg0 : tensor<3x35xf32>) -> tensor<5x7x3xf32> {
|
||||
%0 = tensor.expand_shape %arg0 [[0], [1, 2]]
|
||||
: tensor<3x35xf32> into tensor<3x5x7xf32>
|
||||
%1 = linalg.init_tensor [5, 7, 3] : tensor<5x7x3xf32>
|
||||
%2 = linalg.generic
|
||||
{indexing_maps = [#map2, #map3],
|
||||
iterator_types = ["parallel", "parallel", "parallel"]}
|
||||
ins(%0 : tensor<3x5x7xf32>) outs(%1 : tensor<5x7x3xf32>) {
|
||||
^bb0(%arg2: f32, %arg3: f32):
|
||||
linalg.yield %arg2 : f32
|
||||
} -> tensor<5x7x3xf32>
|
||||
return %2 : tensor<5x7x3xf32>
|
||||
}
|
||||
|
||||
// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2) -> (d1, d0 * 7 + d2)>
|
||||
// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2) -> (d0, d2, d1)>
|
||||
// CHECK: func @generic_op_120_permutation_reshape_producer_fusion
|
||||
// CHECK-NOT: tensor.expand_shape
|
||||
// CHECK: linalg.generic
|
||||
// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]]]
|
||||
|
||||
// -----
|
||||
|
||||
#map0 = affine_map<(d0, d1, d2) -> (d0)>
|
||||
#map1 = affine_map<(d0, d1, d2) -> (d1, d2)>
|
||||
#map2 = affine_map<(d0, d1, d2) -> (d1, d0, d2)>
|
||||
#map3 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
|
||||
func.func @generic_op_102_permultation_reshape_producer_fusion(%arg0 : tensor<3x35xf32>) -> tensor<5x3x7xf32> {
|
||||
%0 = tensor.expand_shape %arg0 [[0], [1, 2]]
|
||||
: tensor<3x35xf32> into tensor<3x5x7xf32>
|
||||
%1 = linalg.init_tensor [5, 3, 7] : tensor<5x3x7xf32>
|
||||
%2 = linalg.generic
|
||||
{indexing_maps = [#map2, #map3],
|
||||
iterator_types = ["parallel", "parallel", "parallel"]}
|
||||
ins(%0 : tensor<3x5x7xf32>) outs(%1 : tensor<5x3x7xf32>) {
|
||||
^bb0(%arg2: f32, %arg3: f32):
|
||||
linalg.yield %arg2 : f32
|
||||
} -> tensor<5x3x7xf32>
|
||||
return %2 : tensor<5x3x7xf32>
|
||||
}
|
||||
|
||||
|
||||
// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2) -> (d1, d0 * 7 + d2)>
|
||||
// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
|
||||
// CHECK: func @generic_op_102_permultation_reshape_producer_fusion
|
||||
// CHECK-NOT: tensor.expand_shape
|
||||
// CHECK: linalg.generic
|
||||
// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]]]
|
||||
|
||||
// -----
|
||||
|
||||
#map0 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
|
||||
#map1 = affine_map<(d0, d1, d2) -> (d1, d0, d2)>
|
||||
#map2 = affine_map<(d0, d1, d2) -> (d0)>
|
||||
#map3 = affine_map<(d0, d1, d2) -> (d1, d2)>
|
||||
func.func @generic_op_102_permultation_reshape_consumer_fusion(%arg0 : tensor<3x5x7xf32>) -> tensor<5x21xf32> {
|
||||
%0 = linalg.init_tensor [5, 3, 7] : tensor<5x3x7xf32>
|
||||
%1 = linalg.generic
|
||||
{indexing_maps = [#map0, #map1],
|
||||
iterator_types = ["parallel", "parallel", "parallel"]}
|
||||
ins(%arg0 : tensor<3x5x7xf32>) outs(%0 : tensor<5x3x7xf32>) {
|
||||
^bb0(%arg2: f32, %arg3 : f32):
|
||||
linalg.yield %arg2 : f32
|
||||
} -> tensor<5x3x7xf32>
|
||||
%2 = tensor.collapse_shape %1 [[0], [1, 2]]
|
||||
: tensor<5x3x7xf32> into tensor<5x21xf32>
|
||||
return %2 : tensor<5x21xf32>
|
||||
}
|
||||
// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
|
||||
// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1, d2) -> (d1, d0 * 7 + d2)>
|
||||
// CHECK: func @generic_op_102_permultation_reshape_consumer_fusion
|
||||
// CHECK-SAME: %[[ARG0:.+]]: tensor<3x5x7xf32>
|
||||
// CHECK: %[[T0:.+]] = linalg.init_tensor [5, 3, 7]
|
||||
// CHECK: %[[T1:.+]] = tensor.collapse_shape %[[T0]]
|
||||
// CHECK-SAME: [0], [1, 2]
|
||||
// CHECK: linalg.generic
|
||||
// CHECK-SAME: indexing_maps = [#[[MAP2]], #[[MAP3]]]
|
||||
// CHECK-SAME: ins(%[[ARG0]] : tensor<3x5x7xf32>)
|
||||
// CHECK-SAME: outs(%[[T1]] : tensor<5x21xf32>)
|
||||
|
||||
// -----
|
||||
|
||||
#map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
|
||||
func.func @generic_op_reshape_consumer_nofusion(%arg0 : tensor<?x?x?x5xf32>,
|
||||
%arg1 : tensor<?x?x?x5xf32>) ->
|
||||
tensor<?x?xf32>
|
||||
{
|
||||
%0 = linalg.generic {
|
||||
indexing_maps = [#map0, #map0, #map0],
|
||||
iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
|
||||
ins(%arg0, %arg1 : tensor<?x?x?x5xf32>, tensor<?x?x?x5xf32>)
|
||||
outs(%arg0 : tensor<?x?x?x5xf32>) {
|
||||
^bb0(%arg3: f32, %arg4: f32, %arg5: f32):
|
||||
%1 = arith.mulf %arg3, %arg4 : f32
|
||||
linalg.yield %1 : f32
|
||||
} -> tensor<?x?x?x5xf32>
|
||||
%1 = tensor.collapse_shape %0 [[0], [1, 2, 3]] :
|
||||
tensor<?x?x?x5xf32> into tensor<?x?xf32>
|
||||
return %1 : tensor<?x?xf32>
|
||||
}
|
||||
// CHECK-LABEL: func @generic_op_reshape_consumer_nofusion
|
||||
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?x?x5xf32>
|
||||
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?x?x5xf32>
|
||||
// CHECK: %[[NOFUSE:.+]] = linalg.generic
|
||||
// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]]
|
||||
// CHECK: %[[RESULT:.+]] = tensor.collapse_shape %[[NOFUSE]]
|
||||
// CHECK: return %[[RESULT]]

// -----

func.func @generic_op_permultation_reshape_consumer_fusion_unused_dim(%arg0 : tensor<6x1xf32>) -> tensor<6xi32> {
  %0 = linalg.init_tensor [6, 1] : tensor<6x1xi32>
  %1 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
                                        affine_map<(d0, d1) -> (d0, d1)>],
      iterator_types = ["parallel", "parallel"]}
      ins(%arg0 : tensor<6x1xf32>) outs(%0 : tensor<6x1xi32>) {
    ^bb0(%arg3: f32, %arg4: i32):
      %5 = arith.fptosi %arg3 : f32 to i32
      linalg.yield %5 : i32
    } -> tensor<6x1xi32>
  %6 = tensor.collapse_shape %1 [[0, 1]] : tensor<6x1xi32> into tensor<6xi32>
  return %6 : tensor<6xi32>
}
// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1) -> (d0, d1)>
// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1) -> (d0)>
// CHECK: func @generic_op_permultation_reshape_consumer_fusion_unused_dim
// CHECK-SAME:   %[[ARG0:.+]]: tensor<6x1xf32>
// CHECK:   %[[T0:.+]] = linalg.init_tensor [6, 1]
// CHECK:   %[[T1:.+]] = tensor.collapse_shape %[[T0]]
// CHECK-SAME:     [0, 1]
// CHECK:   linalg.generic
// CHECK-SAME:     indexing_maps = [#[[MAP0]], #[[MAP1]]]
// CHECK-SAME:     ins(%[[ARG0]] : tensor<6x1xf32>)
// CHECK-SAME:     outs(%[[T1]] : tensor<6xi32>)
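// Collapsing [0, 1] of tensor<6x1xi32> would linearize to d0 * 1 + d1, but d1
// spans a unit extent and only ever takes the value 0, so the folded output
// map simplifies to (d0), as #[[MAP1]] shows.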

// -----

#map0 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d2, d4, d0, d6, d3, d5, d1)>
#map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3, d4, d5, d6)>
func.func @permuted_dims_fusion_expand_shape(%arg0 : tensor<3x8x7x240xf32>) -> tensor<4x6x3x8x2x5x7xf32> {
  %0 = tensor.expand_shape %arg0 [[0], [1, 2], [3], [4, 5, 6]]
      : tensor<3x8x7x240xf32> into tensor<3x2x4x7x8x5x6xf32>
  %1 = linalg.init_tensor [4, 6, 3, 8, 2, 5, 7] : tensor<4x6x3x8x2x5x7xf32>
  %2 = linalg.generic {
      indexing_maps = [#map0, #map1],
      iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel", "parallel"]}
      ins(%0 : tensor<3x2x4x7x8x5x6xf32>) outs(%1 : tensor<4x6x3x8x2x5x7xf32>) {
    ^bb0(%arg1 : f32, %arg2 : f32):
      linalg.yield %arg1 : f32
    } -> tensor<4x6x3x8x2x5x7xf32>
  return %2 : tensor<4x6x3x8x2x5x7xf32>
}
// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d2, d0 + d4 * 4, d6, d1 + d3 * 30 + d5 * 6)>
// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3, d4, d5, d6)>
// CHECK: func @permuted_dims_fusion_expand_shape(
// CHECK-SAME:   %[[ARG0:.+]]: tensor<3x8x7x240xf32>)
// CHECK:   %[[RESULT:.+]] = linalg.generic
// CHECK-SAME:     indexing_maps = [#[[MAP0]], #[[MAP1]]]
// CHECK-SAME:     ins(%[[ARG0]] :
// CHECK:   return %[[RESULT]]
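// Worked example of the linearization: reassociation group [1, 2] has extents
// (2, 4), so its merged index is d4 * 4 + d0; group [4, 5, 6] has extents
// (8, 5, 6), giving d3 * (5 * 6) + d5 * 6 + d1 = d1 + d3 * 30 + d5 * 6,
// exactly the expressions in #[[MAP0]].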

// -----

#map0 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d2, d4, d0, d6, d3, d5, d1)>
#map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3, d4, d5, d6)>
func.func @permuted_dims_fusion_collapse_shape(%arg0 : tensor<4x6x3x8x2x5x7xf32>) -> tensor<3x8x7x240xf32> {
  %0 = linalg.init_tensor [3, 2, 4, 7, 8, 5, 6] : tensor<3x2x4x7x8x5x6xf32>
  %1 = linalg.generic {
      indexing_maps = [#map1, #map0],
      iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel", "parallel"]}
      ins(%arg0 : tensor<4x6x3x8x2x5x7xf32>) outs(%0 : tensor<3x2x4x7x8x5x6xf32>) {
    ^bb0(%arg1 : f32, %arg2 : f32):
      linalg.yield %arg1 : f32
    } -> tensor<3x2x4x7x8x5x6xf32>
  %2 = tensor.collapse_shape %1 [[0], [1, 2], [3], [4, 5, 6]]
      : tensor<3x2x4x7x8x5x6xf32> into tensor<3x8x7x240xf32>
  return %2 : tensor<3x8x7x240xf32>
}
// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3, d4, d5, d6)>
// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d2, d0 + d4 * 4, d6, d1 + d3 * 30 + d5 * 6)>
// CHECK: func @permuted_dims_fusion_collapse_shape(
// CHECK-SAME:   %[[ARG0:.+]]: tensor<4x6x3x8x2x5x7xf32>)
// CHECK:   %[[RESULT:.+]] = linalg.generic
// CHECK-SAME:     indexing_maps = [#[[MAP0]], #[[MAP1]]]
// CHECK-SAME:     ins(%[[ARG0]] :
// CHECK:   return %[[RESULT]]
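// Mirror of @permuted_dims_fusion_expand_shape: the identical linearized
// expression now shows up as #[[MAP1]], the output map, because here the
// collapse_shape folds into the result side of the generic op.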

@ -1,52 +0,0 @@
// RUN: mlir-opt -linalg-fold-reshape-ops-by-linearization=allow-folding-unit-dim-reshapes -split-input-file %s | FileCheck %s

#map = affine_map<(d0, d1) -> (d0, d1)>
func.func @do_not_fold1(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>) -> tensor<?x?x1xf32>
{
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %0 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
  %1 = tensor.dim %arg0, %c1 : tensor<?x?xf32>
  %2 = linalg.init_tensor [%0, %1] : tensor<?x?xf32>
  %3 = linalg.generic {
      indexing_maps = [#map, #map, #map],
      iterator_types = ["parallel", "parallel"]}
      ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
      outs(%2 : tensor<?x?xf32>) {
    ^bb0(%arg2 : f32, %arg3 : f32, %arg4 : f32):
      %5 = arith.addf %arg2, %arg3 : f32
      linalg.yield %5 : f32
    } -> tensor<?x?xf32>
  %4 = tensor.expand_shape %3 [[0], [1, 2]] : tensor<?x?xf32> into tensor<?x?x1xf32>
  return %4 : tensor<?x?x1xf32>
}
// CHECK-LABEL: func @do_not_fold1
// CHECK:   %[[VAL:.+]] = linalg.generic
// CHECK:   tensor.expand_shape %[[VAL]]

// -----

#map = affine_map<(d0, d1) -> (d0, d1)>
func.func @do_not_fold2(%arg0 : tensor<?x?x1xf32>, %arg1 : tensor<?x?xf32>) -> tensor<?x?xf32>
{
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %0 = tensor.collapse_shape %arg0 [[0], [1, 2]] : tensor<?x?x1xf32> into tensor<?x?xf32>
  %1 = tensor.dim %arg1, %c0 : tensor<?x?xf32>
  %2 = tensor.dim %arg1, %c1 : tensor<?x?xf32>
  %3 = linalg.init_tensor [%1, %2] : tensor<?x?xf32>
  %4 = linalg.generic {
      indexing_maps = [#map, #map, #map],
      iterator_types = ["parallel", "parallel"]}
      ins(%0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
      outs(%3 : tensor<?x?xf32>) {
    ^bb0(%arg2 : f32, %arg3 : f32, %arg4 : f32):
      %5 = arith.addf %arg2, %arg3 : f32
      linalg.yield %5 : f32
    } -> tensor<?x?xf32>
  return %4 : tensor<?x?xf32>
}
// CHECK-LABEL: func @do_not_fold2
// CHECK:   %[[VAL:.+]] = tensor.collapse_shape
// CHECK:   linalg.generic
// CHECK-SAME:     ins(%[[VAL]], %{{.+}} : tensor<?x?xf32>, tensor<?x?xf32>)
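// Both @do_not_fold cases run with unit-dimension folding enabled (see the
// RUN line above); they pin down inputs the linearization patterns must still
// leave alone, presumably because the remaining extents are dynamic.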

@ -70,18 +70,18 @@ struct TestLinalgElementwiseFusion
      llvm::cl::desc("Test fusion of generic operations."),
      llvm::cl::init(false)};

  Option<bool> fuseWithReshapeByExpansion{
      *this, "fuse-with-reshape-by-expansion",
      llvm::cl::desc(
          "Test fusion of generic operations with reshape by expansion"),
      llvm::cl::init(false)};

  Option<bool> controlFuseByExpansion{
      *this, "control-fusion-by-expansion",
      llvm::cl::desc(
          "Test controlling fusion of reshape with generic op by expansion"),
      llvm::cl::init(false)};

  Option<bool> pushExpandingReshape{
      *this, "push-expanding-reshape",
      llvm::cl::desc("Test linalg expand_shape -> generic "
                     "to generic -> expand_shape pattern"),
      llvm::cl::init(false)};

  Option<bool> fuseWithReshapeByCollapsing{
      *this, "fuse-with-reshape-by-collapsing",
      llvm::cl::desc("Test linalg expand_shape -> generic fusion patterns that "

@ -109,6 +109,17 @@ struct TestLinalgElementwiseFusion
      return;
    }

    if (fuseWithReshapeByExpansion) {
      RewritePatternSet fusionPatterns(context);
      // Control function that always returns true: every reshape/generic pair
      // the expansion patterns can handle gets fused in this test mode.
      linalg::populateFoldReshapeOpsByExpansionPatterns(
          fusionPatterns, [](const OpResult & /*producer*/,
                             OpOperand & /*consumer*/) { return true; });
      if (failed(applyPatternsAndFoldGreedily(funcOp.getBody(),
                                              std::move(fusionPatterns))))
        return signalPassFailure();
      return;
    }

    if (controlFuseByExpansion) {
      RewritePatternSet fusionPatterns(context);

@ -128,8 +139,9 @@ struct TestLinalgElementwiseFusion
              if (linalgOp && linalgOp.isOutputTensor(&use))
                return true;
            }
            return false;
          }
          return linalg::skipUnitDimReshape(producer, consumer);
          return true;
        };

      linalg::populateFoldReshapeOpsByExpansionPatterns(fusionPatterns,

@ -139,12 +151,6 @@ struct TestLinalgElementwiseFusion
      return;
    }

    if (pushExpandingReshape) {
      RewritePatternSet patterns(context);
      linalg::populatePushReshapeOpsPatterns(patterns);
      (void)applyPatternsAndFoldGreedily(funcOp.getBody(), std::move(patterns));
    }

    if (fuseWithReshapeByCollapsing) {
      RewritePatternSet patterns(context);
      linalg::populateFoldReshapeOpsByCollapsingPatterns(