[mlir][Linalg] Fuse sequence of Linalg operation (on buffers)

Enhance the tile+fuse logic to allow fusing a sequence of operations.

Make sure the value used to obtain tile shape is a
SubViewOp/SubTensorOp. Current logic used to get the bounds of loop
depends on the use of `getOrCreateRange` method on `SubViewOp` and
`SubTensorOp`. Make sure that the value/dim used to compute the range
is from such ops.  This fix is a reasonable WAR, but a btter fix would
be to make `getOrCreateRange` method be a method of `ViewInterface`.

Differential Revision: https://reviews.llvm.org/D90991
This commit is contained in:
MaheshRavishankar 2020-11-23 10:07:34 -08:00
parent 4252f7773a
commit e65a5e5b00
7 changed files with 620 additions and 314 deletions

View File

@ -37,14 +37,6 @@ struct TiledLinalgOp {
SmallVector<Value, 4> tensorResults;
};
struct TiledAndFusedLinalgOps {
LinalgOp op;
SmallVector<LinalgOp, 1> fusedProducers;
SmallVector<LinalgOp, 1> originalProducers;
SmallVector<Operation *, 4> fusedLoops;
SmallVector<Operation *, 4> unfusedLoops;
};
/// Populates patterns for vectorization of all ConvN-D ops.
void populateConvVectorizationPatterns(
MLIRContext *context, SmallVectorImpl<OwningRewritePatternList> &patterns,
@ -73,14 +65,11 @@ void populateLinalgBufferizePatterns(MLIRContext *context,
Optional<TiledLinalgOp> tileLinalgOp(OpBuilder &b, LinalgOp op,
const LinalgTilingOptions &options);
/// Tile and fuse the `op` with its producers. The tile and fuse proceeds in
/// three steps
/// - Find tile loops that are fusable with its producer tile loops (a.k.a. tile
/// + fuse loops).
/// - Tile just these loops of the consumer (root operation) and fuse with
/// the producer.
/// - Tile again the tiled consumer operation produced above to do rest of
/// the tiling specified by the `tilingOptions`.
/// Fuse a sequence of linalg operations (`ops`) using tile-and-fuse. This
/// proceeds as follows:
/// - Find outer parallel loops in these ops that can be fused.
/// - Tile fusable outer parallel loops of the last operation in the sequence.
/// - Fuse the remaining operations with the tiled operation
///
/// For example, consider the sequence of matmul below
///
@ -107,36 +96,39 @@ Optional<TiledLinalgOp> tileLinalgOp(OpBuilder &b, LinalgOp op,
/// : memref<256x32xf32> to memref<16x32xf32, #map0>
/// %3 = subview %arg1[0, 0] [32, 32] [1, 1]
/// : memref<32x32xf32> to memref<32x32xf32, #map1>
/// %4 = subview %arg3[0, 0] [32, 32] [1, 1]
/// : memref<32x32xf32> to memref<32x32xf32, #map1>
/// linalg.matmul
/// ins(%2, %3 : memref<16x32xf32, #map0>, memref<32x32xf32, #map1>)
/// outs(%0 : memref<16x32xf32, #map0>)
/// scf.parallel (%arg6) = (%c0) to (%c32) step (%c8) {
/// scf.for %arg7 = %c0 to %c32 step %c4 {
/// %4 = subview %0[0, %arg7] [16, 4] [1, 1]
/// : memref<16x32xf32, #map0> to memref<16x4xf32, #map0>
/// %5 = subview %arg3[%arg7, %arg6] [4, 8] [1, 1]
/// : memref<32x32xf32> to memref<4x8xf32, #map0>
/// %6 = subview %1[0, %arg6] [16, 8] [1, 1]
/// : memref<16x32xf32, #map0> to memref<16x8xf32, #map0>
/// linalg.matmul
/// ins(%4, %5 : memref<16x4xf32, #map0>, memref<4x8xf32, #map0>)
/// outs(%6 : memref<16x8xf32, #map0>)
/// }
/// scf.yield
/// }
/// scf.yield
/// linalg.matmul
/// ins(%0, %4 : memref<16x4xf32, #map0>, memref<4x8xf32, #map0>)
/// outs(%1 : memref<16x8xf32, #map0>)
/// }
///
/// The following tiling options are handled differently in tile+fuse (compared
/// to tile only)
/// `tilingOptions` are used to tile the corresponding operation in `ops` (the
/// size of the former should be same as size of the latter. Based on how
/// tile+fuse is implemented, the fused loops are generated based on the last
/// operation in the sequence. For example, the tile sizes for the fused loops
/// is obtained from `tilingOptions.back()`. The following tiling options are
/// handled differently in tile+fuse (compared to tile only)
/// - Interchange of the tiling loops is not supported right now.
/// - Distribution is only done for the tile+fuse loops. The tiled loops
/// generated by the second tiling is not distributed.
/// - Only the fused loops are distributed.
struct TiledAndFusedLinalgOps {
/// Operation obtained by tiling the last operation in sequence of `ops`
/// passed to `tileAndFuseLinalgOps`.
LinalgOp op;
/// The dimension of the loops that are fused.
std::set<unsigned> fusedLoopDims;
/// The generated fused operations (created within the fused loops).
SmallVector<LinalgOp, 1> fusedProducers;
/// The fused loop generated.
SmallVector<Operation *, 4> fusedLoops;
};
Optional<TiledAndFusedLinalgOps>
tileAndFuseLinalgOps(PatternRewriter &rewriter, LinalgOp op,
tileAndFuseLinalgOps(OpBuilder &builder, ArrayRef<LinalgOp> ops,
const LinalgDependenceGraph &dependenceGraph,
const LinalgTilingOptions &tilingOptions,
const LinalgFusionOptions &fusionOptions);
const LinalgTilingOptions &tilingOptions);
/// Interchanges the `iterator_types` and `iterator_maps` dimensions of `op`.
/// This is an in-place transformation controlled by `interchangeVector`.

View File

@ -162,13 +162,24 @@ struct ShapeDimension {
// guarantees at least one such dimension is found. If multiple candidates exist
// they must agree by construction (i.e. have the same size) and we just return
// the first one.
static ShapeDimension getShapeDefiningLoopRange(LinalgOp op,
unsigned loopDepth) {
static ShapeDimension
getShapeDefiningLoopRange(LinalgOp op, unsigned loopDepth,
bool fromSubViewOpOnly = false) {
auto maps = op.indexing_maps();
// Iterate over the inputs and outputs in order.
// Extract the subranges from the linearized ranges.
SmallVector<Value, 8> ios(op.getInputsAndOutputBuffers());
for (auto en : llvm::enumerate(ios)) {
// The method `getRangeFromOperandShape` requires using SubViewOp or
// SubTensorOps. If the value isnt defined from there continue.
// todo: The method should be adapted to get the values from
// `ViewInterface`. The interface needs a `getOrCreateRanges` method which
// currently returns a `linalg.range`. The fix here is to move this op to
// `std` dialect and add the method to `ViewInterface`.
if (fromSubViewOpOnly &&
!isa_and_nonnull<SubViewOp, SubTensorOp>(en.value().getDefiningOp()))
continue;
unsigned idx = en.index();
auto map = maps[idx].cast<AffineMapAttr>().getValue();
LLVM_DEBUG(llvm::dbgs()
@ -178,6 +189,9 @@ static ShapeDimension getShapeDefiningLoopRange(LinalgOp op,
Value shape = en.value();
SmallVector<Value, 8> shapeRanges(map.getNumResults(), nullptr);
for (auto en2 : llvm::enumerate(map.getResults())) {
auto dimExpr = en2.value().dyn_cast<AffineDimExpr>();
if (!dimExpr)
continue;
if (loopDepth == en2.value().cast<AffineDimExpr>().getPosition()) {
LLVM_DEBUG(llvm::dbgs() << "getShapeDefiningLoopRange loopDepth: "
<< loopDepth << "\n");
@ -190,49 +204,18 @@ static ShapeDimension getShapeDefiningLoopRange(LinalgOp op,
llvm_unreachable("Expect to be able to extract a shape defining loop range");
}
/// Fuses the producer of `producerIdx` into the loop immediately enclosing
/// `consumer`. This is achieved by "recomputing" the `producer` at the time it
/// is needed just before the `consumer.
///
/// Depending on the type of `consumer.getShapedOperand(consumerIdx)`, there are
/// 2 cases:
/// 1. Buffer case: `producerIdx` is the index of the buffer in
/// `producer.getOutputBuffers()`.
/// 2. Tensor case: `producerIdx` is the index of the tensor in
/// `producer.getResults()`.
static LinalgOp fuse(OpBuilder &b, LinalgOp producer, unsigned producerIdx,
LinalgOp consumer, unsigned consumerIdx) {
Operation *shapeProducingOp =
consumer.getShapedOperand(consumerIdx).getDefiningOp();
assert((isa<SubViewOp>(shapeProducingOp) ||
isa<SubTensorOp>(shapeProducingOp)) &&
"SubviewOp or SubTensorOp expected");
// loopToOperandRangesMaps are permutations-only by construction:
// we can always identify a data dimension with a (at least one) loop
// dimension.
// TODO: extend this with range inference.
AffineMap producerMap = producer.getOutputIndexingMap(producerIdx);
LLVM_DEBUG(llvm::dbgs() << "Producer Idx: " << producerIdx
<< ", producer map: " << producerMap << "\n");
/// Fuse the producer by cloning the `producer`. The `fusedLoopsAndRanges`
/// provides the loop range information for the fused loops. The rest are
/// obtained from the producer itself, since they are not tiled + fused.
static LinalgOp fuse(OpBuilder &b, LinalgOp producer,
const DenseMap<unsigned, Range> &fusedLoopsAndRanges) {
unsigned nPar = producer.getNumParallelLoops();
unsigned nRed = producer.getNumReductionLoops();
unsigned nWin = producer.getNumWindowLoops();
SmallVector<Range, 8> loopRanges(nPar + nRed + nWin);
// Iterate over dimensions identified by the producer map for `producerIdx`.
// This defines a subset of the loop ranges that we need to complete later.
auto loc = consumer.getLoc();
for (auto en : llvm::enumerate(producerMap.getResults())) {
unsigned posInProducerLoop = en.value().cast<AffineDimExpr>().getPosition();
loopRanges[posInProducerLoop] =
isa<SubViewOp>(shapeProducingOp)
? cast<SubViewOp>(shapeProducingOp)
.getOrCreateRanges(b, loc)[en.index()]
: cast<SubTensorOp>(shapeProducingOp)
.getOrCreateRanges(b, loc)[en.index()];
}
for (auto fusedLoops : fusedLoopsAndRanges)
loopRanges[fusedLoops.first] = fusedLoops.second;
// Iterate over all dimensions. For the dimensions not identified by the
// producer map for `producerIdx`, we need to explicitly compute the shape
@ -250,7 +233,45 @@ static LinalgOp fuse(OpBuilder &b, LinalgOp producer, unsigned producerIdx,
}
}
return cloneWithLoopRanges(b, loc, producer, loopRanges);
return cloneWithLoopRanges(b, producer.getLoc(), producer, loopRanges);
}
/// Get the loop range for a dimension `dim` based on the `shapedOperand`. It is
/// expected to be defined by a subview op or a subtensor op.
static Range getRangeFromOperandShape(OpBuilder &b, Location loc,
Value shapedOperand, unsigned dim) {
Operation *shapeProducingOp = shapedOperand.getDefiningOp();
if (auto subViewOp = dyn_cast<SubViewOp>(shapeProducingOp))
return subViewOp.getOrCreateRanges(b, loc)[dim];
if (auto subTensorOp = dyn_cast<SubTensorOp>(shapeProducingOp))
return subTensorOp.getOrCreateRanges(b, loc)[dim];
llvm_unreachable("SubviewOp or SubTensorOp expected");
}
/// Fuses the producer of `producerIdx` into the loop immediately enclosing
/// `consumer`. This is achieved by "recomputing" the `producer` at the time it
/// is needed just before the `consumer.
///
/// Depending on the type of `consumer.getShapedOperand(consumerIdx)`, there are
/// 2 cases:
/// 1. Buffer case: `producerIdx` is the index of the buffer in
/// `producer.getOutputBuffers()`.
/// 2. Tensor case: `producerIdx` is the index of the tensor in
/// `producer.getResults()`.
static LinalgOp fuse(OpBuilder &b, LinalgOp producer, unsigned producerIdx,
LinalgOp consumer, unsigned consumerIdx) {
AffineMap producerMap = producer.getOutputIndexingMap(producerIdx);
LLVM_DEBUG(llvm::dbgs() << "Producer Idx: " << producerIdx
<< ", producer map: " << producerMap << "\n");
DenseMap<unsigned, Range> fusedLoopsAndRanges;
Location loc = consumer.getLoc();
Value shapedOperand = consumer.getShapedOperand(consumerIdx);
for (auto en : llvm::enumerate(producerMap.getResults())) {
unsigned posInProducerLoop = en.value().cast<AffineDimExpr>().getPosition();
fusedLoopsAndRanges[posInProducerLoop] =
getRangeFromOperandShape(b, loc, shapedOperand, en.index());
}
return fuse(b, producer, fusedLoopsAndRanges);
}
// Encode structural fusion safety preconditions.
@ -525,6 +546,69 @@ using FusableOpDependencesTy = llvm::MapVector<
Operation *,
SmallVector<LinalgDependenceGraph::LinalgDependenceGraphElem, 1>>;
/// Returns the mapping from iterations in the consumer that write to the same
/// location as the iterations in the producer. To do so use
/// - indexing map of the fused view in the consumer : consumerIndexMap
/// - indexing map of the fused view in the producer : producerIndexMap
/// consumerLoopToProducerLoop =
/// inverse(producerIndexMap).compose(consumerIndexMap)
static Optional<AffineMap> getConsumerLoopToProducerLoopMap(
LinalgDependenceGraph::LinalgDependenceGraphElem dependence) {
auto producer = cast<LinalgOp>(dependence.dependentOpView.op);
AffineMap producerIndexingMap =
producer.getIndexingMap(dependence.dependentOpView.operandIndex);
auto consumer = cast<LinalgOp>(dependence.indexingOpView.op);
AffineMap consumerIndexingMap =
consumer.getIndexingMap(dependence.indexingOpView.operandIndex);
AffineMap prunedProducerIndexingMap = pruneReductionDimsFromMap(
producer.iterator_types().getValue(), producerIndexingMap);
if (!prunedProducerIndexingMap.isPermutation())
return None;
if (consumerIndexingMap.getNumResults() !=
prunedProducerIndexingMap.getNumResults())
return None;
LLVM_DEBUG({
llvm::dbgs() << "\t producerMap : ";
producerIndexingMap.print(llvm::dbgs());
llvm::dbgs() << " pruned : ";
prunedProducerIndexingMap.print(llvm::dbgs());
llvm::dbgs() << "\n";
llvm::dbgs() << "\t consumerMap : ";
consumerIndexingMap.print(llvm::dbgs());
llvm::dbgs() << "\n";
});
AffineMap invProducerIndexMap = inversePermutation(prunedProducerIndexingMap);
if (!invProducerIndexMap)
return None;
return invProducerIndexMap.compose(consumerIndexingMap);
}
/// Given a projected permutation `map`, returns true if the map changes the
/// order in which the fused loop dimension appear.
static bool doesTransposeAccess(AffineMap map,
const std::set<unsigned> &fusableLoops) {
Optional<unsigned> lastFusableLoop;
for (unsigned pos : llvm::map_range(map.getResults(), [](AffineExpr expr) {
return expr.cast<AffineDimExpr>().getPosition();
})) {
if (!fusableLoops.count(pos))
continue;
if (!lastFusableLoop) {
lastFusableLoop = pos;
continue;
}
if (pos <= lastFusableLoop.getValue())
return true;
lastFusableLoop = pos;
}
return false;
}
/// Returns the positions of the loop in `op` that can be tiled based on the
/// operations that are to be fused with it. For example, in a
///
@ -538,13 +622,7 @@ using FusableOpDependencesTy = llvm::MapVector<
/// 2. Of the parallel loops only some can be fused. Only those loops can be
/// fused such where the fusable loops iteration space only touches one tile
/// of the fused operation. This is because the producer (which is writing
/// the fused subview) has update semantics. To compute this,
/// a. Find the mapping from iterations in the consumer that write to the
/// same location as the iterations in the producer. To do so use
/// - indexing map of the fused view in the consumer : consumerIndexMap
/// - indexing map of the fused view in the producer : producerIndexMap
/// consumerLoopToProducerLoop =
/// inverse(producerIndexMap).compose(consumerIndexMap)
/// the fused subview) has update semantics.
///
/// Since an inverse computation is needed, we need to consider the projection
/// of the producerIndexMap w.r.t the parallel loops. The actual fusable loops
@ -582,8 +660,9 @@ using FusableOpDependencesTy = llvm::MapVector<
/// submap with only parallel loops = affine_map<(i, j) -> (j)>
/// Fused dimensions : j
static std::set<unsigned>
collectTileAndFuseLoops(LinalgOp op,
const FusableOpDependencesTy &fusableDependences) {
collectFusableLoops(ArrayRef<LinalgOp> ops,
const FusableOpDependencesTy &fusableDependences) {
assert(!ops.empty());
auto getNumOuterParallelLoops = [](LinalgOp linalgOp) {
return linalgOp.iterator_types()
.getValue()
@ -594,88 +673,57 @@ collectTileAndFuseLoops(LinalgOp op,
.size();
};
LLVM_DEBUG({
llvm::dbgs() << "Op : ";
op.getOperation()->print(llvm::dbgs(), OpPrintingFlags().useLocalScope());
llvm::dbgs() << "\n";
});
size_t numOuterParallelLoops = getNumOuterParallelLoops(op);
for (auto dependence : fusableDependences) {
linalg::LinalgOp producer = cast<linalg::LinalgOp>(dependence.first);
size_t numOuterParallelLoops = getNumOuterParallelLoops(ops.back());
for (auto op : ops.drop_back()) {
numOuterParallelLoops =
std::min(numOuterParallelLoops, getNumOuterParallelLoops(producer));
std::min(numOuterParallelLoops, getNumOuterParallelLoops(op));
}
std::set<unsigned> fusableLoops;
auto range = llvm::seq<unsigned>(0, numOuterParallelLoops);
fusableLoops.insert(range.begin(), range.end());
for (auto dependence : fusableDependences) {
LLVM_DEBUG({
llvm::dbgs() << "\t fusable :";
for (unsigned i : fusableLoops)
llvm::dbgs() << " " << i;
llvm::dbgs() << "\n";
});
linalg::LinalgOp producer = cast<linalg::LinalgOp>(dependence.first);
assert(!dependence.second.empty() &&
"unexpected producer but not dependences");
AffineMap producerIndexingMap = producer.getIndexingMap(
dependence.second.front().dependentOpView.operandIndex);
AffineMap prunedProducerIndexingMap = pruneReductionDimsFromMap(
producer.iterator_types().getValue(), producerIndexingMap);
if (!prunedProducerIndexingMap.isPermutation())
return {};
for (auto op : reverse(ops)) {
for (auto dependence : fusableDependences.lookup(op)) {
LLVM_DEBUG({
llvm::dbgs() << "\t fusable :";
for (unsigned i : fusableLoops)
llvm::dbgs() << " " << i;
llvm::dbgs() << "\n";
});
AffineMap consumerIndexingMap = op.getIndexingMap(
dependence.second.front().indexingOpView.operandIndex);
if (consumerIndexingMap.getNumResults() !=
prunedProducerIndexingMap.getNumResults())
return {};
Optional<AffineMap> consumerLoopToProducerLoop =
getConsumerLoopToProducerLoopMap(dependence);
if (!consumerLoopToProducerLoop) {
op.emitRemark("failed to get map from consumer loop to producer loop");
return {};
}
// todo: This condition is only an implementation limitation. When fusing
// the operation, if the accesses in the producer/consumer are transposes
// of each other, the loop bounds for the tiled producer can be
// manipulated accordingly. This requires some additional bookkeeping in
// the implementation of tile+fuse that is defered to later.
if (doesTransposeAccess(*consumerLoopToProducerLoop, fusableLoops)) {
op.emitRemark("unhandled fusion when fusion requires permutation");
return {};
}
LLVM_DEBUG({
llvm::dbgs() << "\t producerMap : ";
producerIndexingMap.print(llvm::dbgs());
llvm::dbgs() << " pruned : ";
prunedProducerIndexingMap.print(llvm::dbgs());
llvm::dbgs() << "\n";
llvm::dbgs() << "\t consumerMap : ";
consumerIndexingMap.print(llvm::dbgs());
llvm::dbgs() << "\n";
});
AffineMap invProducerIndexMap =
inversePermutation(prunedProducerIndexingMap);
if (!invProducerIndexMap)
return {};
AffineMap consumerLoopToProducerLoop =
invProducerIndexMap.compose(consumerIndexingMap);
LLVM_DEBUG({
llvm::dbgs() << "\t consumerLoopToProducerLoop : ";
consumerLoopToProducerLoop.print(llvm::dbgs());
});
std::set<unsigned> candidates;
for (AffineExpr expr : consumerLoopToProducerLoop.getResults()) {
AffineDimExpr dimExpr = expr.dyn_cast<AffineDimExpr>();
if (!dimExpr)
continue;
unsigned position = dimExpr.getPosition();
if (fusableLoops.count(position))
candidates.insert(position);
std::set<unsigned> candidates;
for (AffineExpr expr : consumerLoopToProducerLoop->getResults()) {
unsigned position = expr.cast<AffineDimExpr>().getPosition();
if (fusableLoops.count(position))
candidates.insert(position);
}
LLVM_DEBUG({
llvm::dbgs() << "\t candidates :";
for (unsigned i : candidates)
llvm::dbgs() << " " << i;
llvm::dbgs() << "\n";
});
if (candidates.empty())
return {};
std::swap(candidates, fusableLoops);
}
LLVM_DEBUG({
llvm::dbgs() << "\t candidates :";
for (unsigned i : candidates)
llvm::dbgs() << " " << i;
llvm::dbgs() << "\n";
});
if (candidates.empty())
return {};
std::swap(candidates, fusableLoops);
}
return fusableLoops;
@ -683,60 +731,69 @@ collectTileAndFuseLoops(LinalgOp op,
/// Find all dependences that are to be fusable.
static FusableOpDependencesTy
findAllFusableDependences(LinalgOp op,
const LinalgDependenceGraph &dependenceGraph,
const LinalgFusionOptions &fusionOptions) {
findAllFusableDependences(ArrayRef<LinalgOp> ops,
const LinalgDependenceGraph &dependenceGraph) {
FusableOpDependencesTy fusableDependences;
// TODO: Currently fusion would not be legal if the fusable dependence is to
// the same producer but different indexing map in the consumer. Fix this, but
// in the meanwhile disallow such a fusion.
DenseMap<Operation *, AffineMap> fusedProducerIndexingMap;
for (auto operandIndex : fusionOptions.indicesToFuse) {
auto fusableDependence =
findFusableProducer(op, operandIndex, dependenceGraph);
if (!fusableDependence)
return FusableOpDependencesTy{};
LinalgOp producerOp = cast<LinalgOp>(fusableDependence->dependentOpView.op);
// Do not fuse dependences that are to operations not in the same basic
// block. This avoid moving fused operations across loops that might
// themselves carry dependency making the fusion illegal.
if (producerOp.getOperation()->getBlock() !=
op.getOperation()->getBlock()) {
op.emitRemark("unhandled fusion of ops in different basic blocks");
return FusableOpDependencesTy{};
}
// Make sure that the indexing map of the view used for fusion in the
// producer is a projected permutation.
unsigned producerIdx = fusableDependence->dependentOpView.operandIndex;
AffineMap producerMap = producerOp.getIndexingMap(producerIdx);
if (!producerMap.isProjectedPermutation()) {
op.emitRemark("unhandled non permutation indexing map for fused view in "
"producer for operand at index ")
<< operandIndex;
return FusableOpDependencesTy{};
}
for (LinalgOp op : reverse(ops)) {
for (auto operandIndex :
llvm::seq<unsigned>(0, op.getNumInputsAndOutputBuffers())) {
Optional<LinalgDependenceGraph::LinalgDependenceGraphElem>
fusableDependence =
findFusableProducer(op, operandIndex, dependenceGraph);
if (!fusableDependence)
continue;
LinalgOp producerOp =
cast<LinalgOp>(fusableDependence->dependentOpView.op);
// Do not fuse dependences that are to operations not in the same basic
// block. This avoid moving fused operations across loops that might
// themselves carry dependency making the fusion illegal.
if (producerOp.getOperation()->getBlock() !=
op.getOperation()->getBlock()) {
op.emitRemark("unhandled fusion of ops in different basic blocks");
return FusableOpDependencesTy{};
}
// Make sure that the indexing map of the view used for fusion in the
// producer is a projected permutation.
unsigned producerIdx = fusableDependence->dependentOpView.operandIndex;
AffineMap producerMap = producerOp.getIndexingMap(producerIdx);
if (!producerMap.isProjectedPermutation()) {
op.emitRemark(
"unhandled non permutation indexing map for fused view in "
"producer for operand at index ")
<< operandIndex;
return FusableOpDependencesTy{};
}
unsigned consumerIdx = fusableDependence->indexingOpView.operandIndex;
AffineMap consumerMap = op.getIndexingMap(consumerIdx);
if (!consumerMap.isProjectedPermutation()) {
op.emitRemark(
"unhandled case where indexing map for fused view in the consumer is "
"not a projected permutation while fusing at index ")
<< operandIndex;
return FusableOpDependencesTy{};
}
unsigned consumerIdx = fusableDependence->indexingOpView.operandIndex;
AffineMap consumerMap = op.getIndexingMap(consumerIdx);
if (!consumerMap.isProjectedPermutation()) {
op.emitRemark(
"unhandled case where indexing map for fused view in the consumer "
"is "
"not a projected permuration while fusing at index ")
<< operandIndex;
return FusableOpDependencesTy{};
}
// Check if the producer is already a fusion candidate. Cannot fuse this
// dependence if it has a different indexing map when used in the consumer.
if (fusedProducerIndexingMap.count(producerOp.getOperation()) &&
fusedProducerIndexingMap[producerOp.getOperation()] != consumerMap) {
op.emitRemark("unhandled fusion to the same producer but with different "
"indexing maps");
return FusableOpDependencesTy{};
}
fusedProducerIndexingMap[producerOp.getOperation()] = consumerMap;
// Check if the producer is already a fusion candidate. Cannot fuse this
// dependence if it has a different indexing map when used in the
// consumer.
if (fusedProducerIndexingMap.count(producerOp.getOperation()) &&
fusedProducerIndexingMap[producerOp.getOperation()] != consumerMap) {
op.emitRemark(
"unhandled fusion to the same producer but with different "
"indexing maps");
return FusableOpDependencesTy{};
}
fusedProducerIndexingMap[producerOp.getOperation()] = consumerMap;
fusableDependences[producerOp.getOperation()].push_back(*fusableDependence);
fusableDependences[producerOp.getOperation()].push_back(
*fusableDependence);
}
}
return fusableDependences;
}
@ -747,136 +804,120 @@ static bool isZero(Value v) {
return false;
}
/// Tile the fused loops in the root operation, by setting the tile sizes for
/// all other loops to zero (those will be tiled later).
static Optional<TiledLinalgOp> tileRootOperation(
OpBuilder &builder, LinalgOp op, ArrayRef<Value> tileSizeVector,
const LinalgTilingOptions &options, const std::set<unsigned> &fusedLoops) {
SmallVector<Value, 4> tileSizes(tileSizeVector.begin(), tileSizeVector.end());
auto zero = std_constant_index(0);
for (unsigned i = 0, e = tileSizes.size(); i != e; ++i)
if (!fusedLoops.count(i))
tileSizes[i] = zero;
LinalgTilingOptions tileFusedLoopsOptions = options;
tileFusedLoopsOptions.setTileSizes(tileSizes);
return tileLinalgOp(builder, op, tileFusedLoopsOptions);
}
/// Fuse the operations in `fusionCandidates` with `tiledOp`. Latter is expected
/// to be a tiled operation such that it is valid to fuse all operations in
/// `fusionCandidates`, i.e. move the operation within the inter-tile loops of
/// `tiledOp`.
static SmallVector<LinalgOp, 1>
fuseOperations(OpBuilder &builder, LinalgOp tiledOp,
ArrayRef<LinalgOp> fusionCandidates,
const FusableOpDependencesTy &fusableDependences,
const std::set<unsigned> &fusedLoops) {
OpBuilder::InsertionGuard guard(builder);
builder.setInsertionPoint(tiledOp);
DenseMap<unsigned, Range> fusedLoopsAndRanges;
for (unsigned loop : fusedLoops) {
ShapeDimension shapeDim = getShapeDefiningLoopRange(tiledOp, loop, true);
fusedLoopsAndRanges[loop] = getRangeFromOperandShape(
builder, tiledOp.getLoc(), shapeDim.shape, shapeDim.dimension);
}
SmallVector<LinalgOp, 1> fusedOps(fusionCandidates.size());
for (auto candidate : enumerate(llvm::reverse(fusionCandidates))) {
LinalgOp fusedOp = fuse(builder, candidate.value(), fusedLoopsAndRanges);
fusedOps[fusionCandidates.size() - candidate.index() - 1] = fusedOp;
builder.setInsertionPoint(fusedOp);
}
return fusedOps;
}
template <typename LoopType>
static Optional<TiledAndFusedLinalgOps>
tileAndFuseLinalgOpsImpl(PatternRewriter &rewriter, LinalgOp op,
tileAndFuseLinalgOpsImpl(OpBuilder &builder, ArrayRef<LinalgOp> ops,
const LinalgDependenceGraph &dependenceGraph,
const LinalgTilingOptions &tilingOptions,
const LinalgFusionOptions &fusionOptions) {
assert(op.hasBufferSemantics() && "expected linalg op with buffer semantics");
// Some of the tiling options might not be supportable with tile and fuse.
// TODO: Support interchange with tile + fuse.
const LinalgTilingOptions &tilingOptions) {
if (ops.empty())
return llvm::None;
LinalgOp rootOp = ops.back();
for (auto op : enumerate(ops)) {
// TODO: Nothing in the fusion of sequence of ops is specific to
// buffers. This check can be removed after it is tested on tensors.
LinalgOp linalgOp = op.value();
if (!linalgOp.hasBufferSemantics()) {
linalgOp.emitError("tile and fuse only tested for buffer operation");
return llvm::None;
}
}
// TODO: Support interchange with tile + fuse. This might actually help do
// better fusion.
if (!tilingOptions.interchangeVector.empty()) {
op.emitError("unable to handle tile and fuse with interchange");
rootOp.emitError("unable to handle tile and fuse with interchange");
return llvm::None;
}
OpBuilder::InsertionGuard g(rewriter);
rewriter.setInsertionPoint(op);
ScopedContext scope(rewriter, op.getLoc());
OpBuilder::InsertionGuard guard(builder);
builder.setInsertionPoint(rootOp);
ScopedContext scope(builder, rootOp.getLoc());
// Find all the producers.
FusableOpDependencesTy fusableDependences =
findAllFusableDependences(op, dependenceGraph, fusionOptions);
findAllFusableDependences(ops, dependenceGraph);
if (fusableDependences.empty())
return llvm::None;
// Enforce the convention that "tiling by zero" skips tiling a particular
// dimension. This convention is significantly simpler to handle instead of
// adjusting affine maps to account for missing dimensions.
auto nLoops = op.getNumLoops();
SmallVector<Value, 4> tileSizeVector =
tilingOptions.tileSizeComputationFunction(rewriter, op);
if (tileSizeVector.size() < nLoops) {
auto zero = std_constant_index(0);
tileSizeVector.append(nLoops - tileSizeVector.size(), zero);
}
TiledAndFusedLinalgOps ret;
// Find the loops that can be tiled and fused.
std::set<unsigned> tileFuseLoops =
collectTileAndFuseLoops(op, fusableDependences);
ret.fusedLoopDims = collectFusableLoops(ops, fusableDependences);
// If there are no fusable dependences or there are no tile+fusable loops,
// just return.
if (tileFuseLoops.empty()) {
if (ret.fusedLoopDims.empty()) {
return llvm::None;
}
// Get the tile sizes for the first and second tiling steps. For the first
// step the tile size are set to zero for the loops that arent
// fused. Similarly for the second step, the tile sizes are set to zero for
// the loops that are fused. For example, if for the following input
//
// ```
// linalg.add ins(%a, %b) outs(%c)
// linalg.matmul ins(%d, %c) outs(%e)
// ```
//
// if the tile sizes of the `{i, j, k}` loops where given as `{ti, tj, tk}`
// respectively, and since only `j` can be tiled and fused. The tile sizes
// would be `{0, t_j, 0}` for the first tiling that tiles just the fusable
// loops. The second tiling would be use tile sizes of `{t_i, 0, t_k}` to tile
// the tiled matmul generated by the first tiling step.
SmallVector<Value, 4> tileAndFuseSizes, tileSizes;
for (auto tileSize : enumerate(tileSizeVector)) {
auto zero = std_constant_index(0);
if (tileFuseLoops.count(tileSize.index())) {
tileAndFuseSizes.push_back(tileSize.value());
tileSizes.push_back(zero);
} else {
tileSizes.push_back(tileSize.value());
tileAndFuseSizes.push_back(zero);
}
}
// Tile for the loops that can be fused.
LinalgTilingOptions firstTilingOptions = tilingOptions;
firstTilingOptions.setTileSizes(tileAndFuseSizes);
Optional<TiledLinalgOp> firstTiledOp =
tileLinalgOp(rewriter, op, firstTilingOptions);
if (!firstTiledOp)
// Tile the fused loops in the last operation in the list.
SmallVector<Value, 4> tileSizeVector =
tilingOptions.tileSizeComputationFunction(builder, rootOp);
Optional<TiledLinalgOp> tiledRootOp = tileRootOperation(
builder, rootOp, tileSizeVector, tilingOptions, ret.fusedLoopDims);
if (!tiledRootOp) {
rootOp.emitError("failed to tile the fused loops");
return llvm::None;
ret.op = firstTiledOp->op;
ret.fusedLoops.assign(firstTiledOp->loops.begin(), firstTiledOp->loops.end());
rewriter.setInsertionPoint(ret.op);
// Fuse the operands.
for (auto dependence : fusableDependences) {
LinalgOp producerOp = cast<LinalgOp>(dependence.first);
unsigned producerIdx =
dependence.second.front().dependentOpView.operandIndex;
unsigned consumerIdx =
dependence.second.front().indexingOpView.operandIndex;
LinalgOp fusedOp = fuse(rewriter, producerOp,
producerOp.getOutputIndex(producerIdx).getValue(),
ret.op, consumerIdx);
ret.fusedProducers.push_back(fusedOp);
ret.originalProducers.push_back(producerOp);
}
if (!llvm::all_of(tileSizes, isZero)) {
// Tile the remaining loops of the root operation.
LinalgTilingOptions secondTilingOptions = tilingOptions;
// The distribution is done only for the tile+fused loops.
secondTilingOptions.distribution = llvm::None;
secondTilingOptions.setTileSizes(tileSizes);
Optional<TiledLinalgOp> secondTiledOp =
tileLinalgOp(rewriter, ret.op, secondTilingOptions);
if (!secondTiledOp)
return llvm::None;
ret.unfusedLoops.assign(secondTiledOp->loops.begin(),
secondTiledOp->loops.end());
rewriter.eraseOp(ret.op);
ret.op = secondTiledOp->op;
}
ret.op = tiledRootOp->op;
ret.fusedLoops.assign(tiledRootOp->loops.begin(), tiledRootOp->loops.end());
// Fuse the other operations into the fused inter-tile loops produced above.
ret.fusedProducers = fuseOperations(builder, ret.op, ops.drop_back(),
fusableDependences, ret.fusedLoopDims);
return ret;
}
Optional<TiledAndFusedLinalgOps>
mlir::linalg::tileAndFuseLinalgOps(PatternRewriter &rewriter, LinalgOp op,
mlir::linalg::tileAndFuseLinalgOps(OpBuilder &builder, ArrayRef<LinalgOp> ops,
const LinalgDependenceGraph &dependenceGraph,
const LinalgTilingOptions &tilingOptions,
const LinalgFusionOptions &fusionOptions) {
const LinalgTilingOptions &tilingOptions) {
switch (tilingOptions.loopType) {
case LinalgTilingLoopType::Loops:
return tileAndFuseLinalgOpsImpl<scf::ForOp>(rewriter, op, dependenceGraph,
tilingOptions, fusionOptions);
return tileAndFuseLinalgOpsImpl<scf::ForOp>(builder, ops, dependenceGraph,
tilingOptions);
case LinalgTilingLoopType::ParallelLoops:
return tileAndFuseLinalgOpsImpl<scf::ParallelOp>(
rewriter, op, dependenceGraph, tilingOptions, fusionOptions);
builder, ops, dependenceGraph, tilingOptions);
default:;
}
return llvm::None;

View File

@ -165,17 +165,69 @@ LogicalResult mlir::linalg::LinalgBaseTileAndFusePattern::matchAndRewrite(
if (!linalgOp.hasBufferSemantics())
return failure();
DenseSet<Operation *> producers;
producers.insert(linalgOp);
for (auto dependence : dependenceGraph.getDependentOperations(linalgOp)) {
if (!fusionOptions.indicesToFuse.count(
dependence.indexingOpView.operandIndex))
continue;
if (isa<LinalgOp>(dependence.dependentOpView.op))
producers.insert(dependence.dependentOpView.op);
}
SmallVector<LinalgOp, 1> fusionOps;
for (auto it = op->getBlock()->begin(), ie = Block::iterator(op); it != ie;
++it) {
auto producerLinalgOp = dyn_cast<LinalgOp>(&(*it));
if (producerLinalgOp && producers.count(producerLinalgOp))
fusionOps.push_back(producerLinalgOp);
}
fusionOps.push_back(linalgOp);
SmallVector<Value, 4> tileSizes =
tilingOptions.tileSizeComputationFunction(rewriter, op);
LinalgTilingOptions instanceTilingOptions = tilingOptions;
instanceTilingOptions.setTileSizes(tileSizes);
Optional<TiledAndFusedLinalgOps> tiledAndFusedOps = tileAndFuseLinalgOps(
rewriter, op, dependenceGraph, tilingOptions, fusionOptions);
rewriter, fusionOps, dependenceGraph, instanceTilingOptions);
if (!tiledAndFusedOps)
return failure();
// Tile the unfused loops;
SmallVector<Value, 4> unfusedLoopTileSizes;
Value zero = rewriter.create<ConstantIndexOp>(op->getLoc(), 0);
for (auto tileSize : enumerate(tileSizes)) {
if (tiledAndFusedOps->fusedLoopDims.count(tileSize.index()))
unfusedLoopTileSizes.push_back(zero);
else
unfusedLoopTileSizes.push_back(tileSize.value());
}
// Tile the loop only if there is a non-zero tile size.
if (unfusedLoopTileSizes.size() > linalgOp.getNumLoops())
unfusedLoopTileSizes.resize(linalgOp.getNumLoops());
if (llvm::any_of(unfusedLoopTileSizes, [](Value val) {
if (auto cst = val.getDefiningOp<ConstantIndexOp>())
return cst.getValue() != 0;
return true;
})) {
LinalgTilingOptions unfusedTilingOptions = tilingOptions;
unfusedTilingOptions.setTileSizes(unfusedLoopTileSizes);
Optional<TiledLinalgOp> unfusedTiledOp =
tileLinalgOp(rewriter, tiledAndFusedOps->op, unfusedTilingOptions);
if (!unfusedTiledOp)
return failure();
rewriter.eraseOp(tiledAndFusedOps->op);
tiledAndFusedOps->op = unfusedTiledOp->op;
}
marker.replaceLinalgMarker(rewriter, tiledAndFusedOps->op.getOperation());
for (auto fusedOp : tiledAndFusedOps->fusedProducers) {
fusedOpMarker.replaceLinalgMarker(rewriter, fusedOp.getOperation());
}
for (auto origProducerOp : tiledAndFusedOps->originalProducers)
for (auto origProducerOp : ArrayRef<LinalgOp>(fusionOps).drop_back()) {
originalOpMarker.replaceLinalgMarker(rewriter,
origProducerOp.getOperation());
}
rewriter.updateRootInPlace(
op, [&]() { originalOpMarker.replaceLinalgMarker(rewriter, op); });
return success();

View File

@ -47,7 +47,9 @@ module {
// CHECK: %[[TILE_N_2:.+]] = affine.min #[[MAP2]](%[[IV1]])[%[[N_2]]]
// CHECK: %[[SV3:.+]] = subview %[[ARG2]][%[[IV0]], %[[IV1]]]
// CHECK-SAME: [%[[TILE_M_2]], %[[TILE_N_2]]]
// CHECK: linalg.fill(%[[SV3]], %[[CST]])
// CHECK: %[[SV3_2:.+]] = subview %[[ARG2]][%[[IV0]], %[[IV1]]]
// CHECK-SAME: [%[[TILE_M]], %[[TILE_N]]]
// CHECK: linalg.fill(%[[SV3_2]], %[[CST]])
// CHECK-SAME: __internal_linalg_transform__ = "after_basic_fusion_producer"
// CHECK: scf.for %[[IV2:.+]] = %[[C0]] to %[[K]] step %[[C16]] {
// CHECK: %[[TILE_K:.+]] = affine.min #[[MAP3]](%[[IV2]])[%[[K]]]
@ -109,9 +111,12 @@ module {
// CHECK: %[[TILE_N_2:.+]] = affine.min #[[MAP0]](%[[IV0]])[%[[N_2]]]
// CHECK: %[[SV2:.+]] = subview %[[ARG3]][0, %[[IV0]]]
// CHECK-SAME: [%[[M]], %[[TILE_N_2]]]
// CHECK: %[[K_2:.+]] = dim %[[ARG1]], %[[C0]]
// CHECK: %[[SV3:.+]] = subview %[[ARG1]][0, %[[IV0]]]
// CHECK-SAME: [%[[K]], %[[TILE_N]]]
// CHECK: linalg.copy(%[[SV3]], %[[SV1]])
// CHECK-SAME: [%[[K_2]], %[[TILE_N]]]
// CHECK: %[[SV3_2:.+]] = subview %[[ARG2]][0, %[[IV0]]]
// CHECK-SAME: [%[[K_2]], %[[TILE_N]]]
// CHECK: linalg.copy(%[[SV3]], %[[SV3_2]])
// CHECK-SAME: __internal_linalg_transform__ = "after_rhs_fusion_producer"
// CHECK-NOT: linalg.fill
// CHECK-DAG: %[[M_2:.+]] = dim %[[ARG0]], %[[C0]]
@ -186,11 +191,16 @@ module {
// CHECK: %[[N:.+]] = dim %[[ARG3]], %[[C1]]
// CHECK: %[[SV2:.+]] = subview %[[ARG3]][%[[IV0]], 0]
// CHECK-SAME: [%[[TILE_M_2]], %[[N]]]
// CHECK: %[[SV2_2:.+]] = subview %[[ARG3]][%[[IV0]], 0]
// CHECK-SAME: [%[[TILE_M]], %[[N]]]
// CHECK: %[[K_2:.+]] = dim %[[ARG0]], %[[C1]]
// CHECK: %[[SV3:.+]] = subview %[[ARG0]][%[[IV0]], 0]
// CHECK-SAME: [%[[TILE_M]], %[[K]]]
// CHECK: linalg.copy(%[[SV3]], %[[SV1]])
// CHECK-SAME: [%[[TILE_M]], %[[K_2]]]
// CHECK: %[[SV3_2:.+]] = subview %[[ARG1]][%[[IV0]], 0]
// CHECK-SAME: [%[[TILE_M]], %[[K_2]]]
// CHECK: linalg.copy(%[[SV3]], %[[SV3_2]])
// CHECK-SAME: __internal_linalg_transform__ = "after_two_operand_fusion_producer"
// CHECK: linalg.fill(%[[SV2]], %[[CST]])
// CHECK: linalg.fill(%[[SV2_2]], %[[CST]])
// CHECK-SAME: __internal_linalg_transform__ = "after_two_operand_fusion_producer"
// CHECK-DAG: %[[N_2:.+]] = dim %[[ARG2]], %[[C1]]
// CHECK: scf.parallel (%[[IV1:.+]]) =
@ -261,15 +271,18 @@ module {
// CHECK: %[[N:.+]] = dim %[[ARG4]], %[[C1]]
// CHECK: %[[SV2:.+]] = subview %[[ARG4]][%[[IV0]], 0]
// CHECK-SAME: [%[[TILE_M_2]], %[[N]]]
// CHECK: %[[K2_2:.+]] = dim %[[ARG1]], %[[C1]]
// CHECK: %[[K1:.+]] = dim %[[ARG0]], %[[C1]]
// CHECK: %[[SV3:.+]] = subview %[[ARG0]][%[[IV0]], 0]
// CHECK-SAME: [%[[TILE_M]], %[[K1]]]
// CHECK: %[[SV4:.+]] = subview %[[ARG1]][0, 0] [%[[K1]], %[[K2]]]
// CHECK: %[[SV4:.+]] = subview %[[ARG1]][0, 0] [%[[K1]], %[[K2_2]]]
// CHECK: %[[SV1_2:.+]] = subview %[[ARG2]][%[[IV0]], 0]
// CHECK-SAME: [%[[TILE_M]], %[[K2_2]]]
// CHECK: linalg.matmul
// CHECK-SAME: __internal_linalg_transform__ = "after_lhs_fusion_producer"
// CHECK-SAME: ins(%[[SV3]], %[[SV4]]
// CHECK-SAME: : memref<?x?xf32, #[[MAP1]]>, memref<?x?xf32, #[[MAP1]]>)
// CHECK-SAME: outs(%[[SV1]] : memref<?x?xf32, #[[MAP1]]>)
// CHECK-SAME: outs(%[[SV1_2]] : memref<?x?xf32, #[[MAP1]]>)
// CHECK-DAG: %[[N_2:.+]] = dim %[[ARG3]], %[[C1]]
// CHECK: scf.parallel (%[[IV1:.+]]) =
// CHECK-SAME: (%[[C0]]) to (%[[N_2]]) step (%[[C64]]) {
@ -413,3 +426,30 @@ module {
return
}
}
// -----
module {
func @basic_conv_fusion(%arg0: memref<?x?x?x?xf32>, %arg1: memref<?x?x?x?xf32>,
%arg2: memref<?x?x?x?xf32>) {
%cst = constant 0.000000e+00 : f32
linalg.fill(%arg2, %cst) : memref<?x?x?x?xf32>, f32
linalg.conv(%arg0, %arg1, %arg2) {
dilations = [1, 1], strides = [1, 1],
__internal_linalg_transform__ = "basic_fusion"} :
memref<?x?x?x?xf32>, memref<?x?x?x?xf32>, memref<?x?x?x?xf32>
return
}
}
// CHECK: func @basic_conv_fusion
// CHECK: linalg.fill
// CHECK-SAME: __internal_linalg_transform__ = "after_basic_fusion_original"
// CHECK: scf.parallel (%{{.+}}, %{{.+}}, %{{.+}})
// CHECK-SAME: {
// CHECK: linalg.fill
// CHECK-SAME: __internal_linalg_transform__ = "after_basic_fusion_producer"
// CHECK: linalg.conv
// CHECK-SAME: __internal_linalg_transform__ = "after_basic_fusion"
// CHECK: }
// CHECK: linalg.conv
// CHECK-SAME: __internal_linalg_transform__ = "after_basic_fusion_original"

View File

@ -0,0 +1,133 @@
// RUN: mlir-opt -pass-pipeline="func(test-linalg-tile-and-fuse{tile-sizes=16,32,64}),canonicalize,cse" -split-input-file %s | FileCheck %s
module {
func @three_op_fusion(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>,
%arg2: memref<?xf32>, %arg3 : memref<?x?xf32>) {
%cst = constant 0.000000e+00 : f32
%c0 = constant 0 : index
%c1 = constant 1 : index
%d0 = dim %arg0, %c0 : memref<?x?xf32>
%d1 = dim %arg1, %c1 : memref<?x?xf32>
%0 = alloc(%d0, %d1) : memref<?x?xf32>
linalg.fill(%0, %cst) : memref<?x?xf32>, f32
linalg.matmul ins(%arg0, %arg1 : memref<?x?xf32>, memref<?x?xf32>)
outs(%0 : memref<?x?xf32>)
linalg.generic
{indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
affine_map<(d0, d1) -> (d1)>,
affine_map<(d0, d1) -> (d0, d1)>],
iterator_types = ["parallel", "parallel"]}
ins(%0, %arg2 : memref<?x?xf32>, memref<?xf32>)
outs(%arg3 : memref<?x?xf32>) {
^bb0(%arg4 : f32, %arg5 : f32, %arg6 : f32) :
%5 = addf %arg4, %arg5 : f32
linalg.yield %5 : f32
}
return
}
}
// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>
// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0)[s0] -> (d0 + s0)>
// CHECK: func @three_op_fusion
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: memref<?x?xf32>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: memref<?x?xf32>
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: memref<?xf32>
// CHECK-SAME: %[[ARG3:[a-zA-Z0-9_]+]]: memref<?x?xf32>
// CHECK: %[[TEMP:.+]] = alloc(%{{.*}}, %{{.*}}) : memref<?x?xf32>
// CHECK: scf.parallel (%[[IV0:.+]], %[[IV1:.+]]) = {{.*}} {
// CHECK-DAG: %[[SV_TEMP:.+]] = subview %[[TEMP]][%[[IV0]], %[[IV1]]]
// CHECK-DAG: %[[SV_ARG2:.+]] = subview %[[ARG2]][%[[IV1]]]
// CHECK-DAG: %[[SV_ARG3:.+]] = subview %[[ARG3]][%[[IV0]], %[[IV1]]]
// CHECK-DAG: %[[SV_ARG0:.+]] = subview %[[ARG0]][%[[IV0]], 0]
// CHECK-DAG: %[[SV_ARG1:.+]] = subview %[[ARG1]][0, %[[IV1]]]
// CHECK: linalg.fill(%[[SV_TEMP]], %{{.+}})
// CHECK: linalg.matmul
// CHECK-SAME: ins(%[[SV_ARG0]], %[[SV_ARG1]]
// CHECK-SAME: : memref<?x?xf32, #[[MAP2]]>, memref<?x?xf32, #[[MAP2]]>)
// CHECK-SAME: outs(%[[SV_TEMP]] : memref<?x?xf32, #[[MAP2]]>)
// CHECK: linalg.generic
// CHECK-SAME: ins(%[[SV_TEMP]], %[[SV_ARG2]]
// CHECK-SAME: : memref<?x?xf32, #[[MAP2]]>, memref<?xf32, #[[MAP3]]>)
// CHECK-SAME: outs(%[[SV_ARG3]] : memref<?x?xf32, #[[MAP2]]>)
// CHECK: scf.yield
// CHECK: }
// -----
module {
func @sequence_of_matmul(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>,
%arg2: memref<?x?xf32>, %arg3: memref<?x?xf32>,
%arg4: memref<?x?xf32>) {
%cst = constant 0.000000e+00 : f32
%c0 = constant 0 : index
%c1 = constant 1 : index
%m = dim %arg0, %c0 : memref<?x?xf32>
%n1 = dim %arg1, %c1 : memref<?x?xf32>
%n2 = dim %arg2, %c1 : memref<?x?xf32>
%n3 = dim %arg3, %c1 : memref<?x?xf32>
%0 = alloc(%m, %n1) : memref<?x?xf32>
%1 = alloc(%m, %n2) : memref<?x?xf32>
linalg.fill(%0, %cst) : memref<?x?xf32>, f32
linalg.matmul ins(%arg0, %arg1 : memref<?x?xf32>, memref<?x?xf32>)
outs(%0 : memref<?x?xf32>)
linalg.fill(%1, %cst) : memref<?x?xf32>, f32
linalg.matmul ins(%0, %arg2 : memref<?x?xf32>, memref<?x?xf32>)
outs(%1 : memref<?x?xf32>)
linalg.fill(%arg4, %cst) : memref<?x?xf32>, f32
linalg.matmul ins(%1, %arg3 : memref<?x?xf32>, memref<?x?xf32>)
outs(%arg4 : memref<?x?xf32>)
return
}
}
// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0)[s0] -> (16, -d0 + s0)>
// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>
// CHECK: func @sequence_of_matmul
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: memref<?x?xf32>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: memref<?x?xf32>
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: memref<?x?xf32>
// CHECK-SAME: %[[ARG3:[a-zA-Z0-9_]+]]: memref<?x?xf32>
// CHECK-SAME: %[[ARG4:[a-zA-Z0-9_]+]]: memref<?x?xf32>
// CHECK-DAG: %[[C0:.+]] = constant 0 : index
// CHECK-DAG: %[[C1:.+]] = constant 1 : index
// CHECK-DAG: %[[C16:.+]] = constant 16 : index
// CHECK-DAG: %[[M:.+]] = dim %[[ARG0]], %[[C0]]
// CHECK-DAG: %[[N1:.+]] = dim %[[ARG1]], %[[C1]]
// CHECK-DAG: %[[N2:.+]] = dim %[[ARG2]], %[[C1]]
// CHECK: %[[ALLOC1:.+]] = alloc(%[[M]], %[[N1]])
// CHECK: %[[ALLOC2:.+]] = alloc(%[[M]], %[[N2]])
// CHECK: scf.parallel (%[[IV0:.+]]) = (%[[C0]]) to (%[[M]])
// CHECK-SAME: step (%[[C16]]) {
// CHECK: %[[TILE_M:.+]] = affine.min #[[MAP0]](%[[IV0]])[%[[M]]]
// CHECK: %[[SV_ALLOC2:.+]] = subview %[[ALLOC2]][%[[IV0]], 0]
// CHECK-SAME: [%[[TILE_M]], %[[N2]]]
// CHECK: %[[M_2:.+]] = dim %[[ARG4]], %[[C0]]
// CHECK: %[[TILE_M_2:.+]] = affine.min #[[MAP0]](%[[IV0]])[%[[M_2]]]
// CHECK: %[[N3:.+]] = dim %[[ARG4]], %[[C1]]
// CHECK: %[[SV_ARG4:.+]] = subview %[[ARG4]][%[[IV0]], 0]
// CHECK-SAME: [%[[TILE_M_2]], %[[N3]]]
// CHECK: %[[SV_ARG4_2:.+]] = subview %[[ARG4]][%[[IV0]], 0]
// CHECK-SAME: [%[[TILE_M]], %[[N3]]]
// CHECK: %[[SV_ALLOC1:.+]] = subview %[[ALLOC1]][%[[IV0]], 0]
// CHECK-SAME: [%[[TILE_M]], %[[N1]]]
// CHECK: %[[SV_ARG2:.+]] = subview %[[ARG2]][0, 0] [%[[N1]], %[[N2]]]
// CHECK: %[[N0:.+]] = dim %[[ARG0]], %[[C1]]
// CHECK: %[[SV_ARG0:.+]] = subview %[[ARG0]][%[[IV0]], 0]
// CHECK-SAME: [%[[TILE_M:.+]], %[[N0]]]
// CHECK: %[[SV_ARG1:.+]] = subview %[[ARG1]][0, 0] [%[[N0]], %[[N1]]]
// CHECK: linalg.fill(%[[SV_ALLOC1]], %{{.+}})
// CHECK: linalg.matmul ins(%[[SV_ARG0]], %[[SV_ARG1]]
// CHECK-SAME: : memref<?x?xf32, #[[MAP1]]>, memref<?x?xf32, #[[MAP1]]>)
// CHECK-SAME: outs(%[[SV_ALLOC1]] : memref<?x?xf32, #[[MAP1]]>)
// CHECK: linalg.fill(%[[SV_ALLOC2]], %{{.+}})
// CHECK: linalg.matmul ins(%[[SV_ALLOC1]], %[[SV_ARG2]]
// CHECK-SAME: : memref<?x?xf32, #[[MAP1]]>, memref<?x?xf32, #[[MAP1]]>)
// CHECK-SAME: outs(%[[SV_ALLOC2]] : memref<?x?xf32, #[[MAP1]]>)
// CHECK: linalg.fill(%[[SV_ARG4_2]], %{{.+}})
// CHECK: linalg.matmul ins(%[[SV_ALLOC2]], %[[ARG3]]
// CHECK-SAME: : memref<?x?xf32, #[[MAP1]]>, memref<?x?xf32>)
// CHECK-SAME: outs(%[[SV_ARG4]] : memref<?x?xf32, #[[MAP1]]>)
// CHECK: scf.yield
// CHECK: }

View File

@ -38,7 +38,8 @@ struct TestLinalgFusionTransforms
static void fillFusionPatterns(MLIRContext *context,
const LinalgDependenceGraph &dependenceGraph,
OwningRewritePatternList &patterns) {
patterns.insert<LinalgTileAndFusePattern<MatmulOp>>(
patterns.insert<LinalgTileAndFusePattern<MatmulOp>,
LinalgTileAndFusePattern<ConvOp>>(
context, dependenceGraph,
LinalgTilingOptions()
.setTileSizes({32, 64, 16})
@ -197,6 +198,44 @@ struct TestLinalgGreedyFusion
}
}
};
/// Pass to test tile and fuse of sequence of operations. Intended only for
/// testing.
struct TestLinalgTileAndFuseSequencePass
: public PassWrapper<TestLinalgTileAndFuseSequencePass, FunctionPass> {
TestLinalgTileAndFuseSequencePass() = default;
TestLinalgTileAndFuseSequencePass(
const TestLinalgTileAndFuseSequencePass &pass){};
ListOption<int64_t> tileSizes{
*this, "tile-sizes", llvm::cl::desc("Tile sizes to use for ops"),
llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated};
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<AffineDialect, linalg::LinalgDialect, scf::SCFDialect>();
}
void runOnFunction() override {
FuncOp funcOp = getOperation();
auto &blocks = funcOp.getBody().getBlocks();
if (!llvm::hasSingleElement(blocks)) {
return;
}
SmallVector<LinalgOp, 2> linalgOps =
llvm::to_vector<2>(blocks.front().getOps<LinalgOp>());
Aliases aliases;
LinalgDependenceGraph dependenceGraph(aliases, linalgOps);
OpBuilder builder(funcOp.getContext());
Optional<TiledAndFusedLinalgOps> tileAndFuseOps = tileAndFuseLinalgOps(
builder, linalgOps, dependenceGraph,
LinalgTilingOptions().setTileSizes(tileSizes).setLoopType(
LinalgTilingLoopType::ParallelLoops));
if (!tileAndFuseOps)
return signalPassFailure();
for (auto op : linalgOps)
op.erase();
}
};
} // namespace
namespace mlir {
@ -211,5 +250,12 @@ void registerTestLinalgGreedyFusion() {
"test-linalg-greedy-fusion",
"Test Linalg fusion by applying a greedy test transformation.");
}
void registerTestLinalgTileAndFuseSequencePass() {
PassRegistration<TestLinalgTileAndFuseSequencePass>
testTileAndFuseSequencePass(
"test-linalg-tile-and-fuse",
"Test Linalg tiling and fusion of a sequence of Linalg operations.");
}
} // namespace test
} // namespace mlir

View File

@ -74,6 +74,7 @@ void registerTestLinalgCodegenStrategy();
void registerTestLinalgFusionTransforms();
void registerTestLinalgGreedyFusion();
void registerTestLinalgHoisting();
void registerTestLinalgTileAndFuseSequencePass();
void registerTestLinalgTransforms();
void registerTestLivenessPass();
void registerTestLoopFusion();
@ -141,6 +142,7 @@ void registerTestPasses() {
test::registerTestLinalgFusionTransforms();
test::registerTestLinalgGreedyFusion();
test::registerTestLinalgHoisting();
test::registerTestLinalgTileAndFuseSequencePass();
test::registerTestLinalgTransforms();
test::registerTestLivenessPass();
test::registerTestLoopFusion();