forked from OSchip/llvm-project
[mlir][Linalg] Fuse sequence of Linalg operation (on buffers)
Enhance the tile+fuse logic to allow fusing a sequence of operations. Make sure the value used to obtain tile shape is a SubViewOp/SubTensorOp. The current logic used to get the bounds of a loop depends on the use of the `getOrCreateRange` method on `SubViewOp` and `SubTensorOp`. Make sure that the value/dim used to compute the range is from such ops. This fix is a reasonable workaround (WAR), but a better fix would be to make the `getOrCreateRange` method be a method of `ViewInterface`. Differential Revision: https://reviews.llvm.org/D90991
This commit is contained in:
parent
4252f7773a
commit
e65a5e5b00
|
@ -37,14 +37,6 @@ struct TiledLinalgOp {
|
|||
SmallVector<Value, 4> tensorResults;
|
||||
};
|
||||
|
||||
struct TiledAndFusedLinalgOps {
|
||||
LinalgOp op;
|
||||
SmallVector<LinalgOp, 1> fusedProducers;
|
||||
SmallVector<LinalgOp, 1> originalProducers;
|
||||
SmallVector<Operation *, 4> fusedLoops;
|
||||
SmallVector<Operation *, 4> unfusedLoops;
|
||||
};
|
||||
|
||||
/// Populates patterns for vectorization of all ConvN-D ops.
|
||||
void populateConvVectorizationPatterns(
|
||||
MLIRContext *context, SmallVectorImpl<OwningRewritePatternList> &patterns,
|
||||
|
@ -73,14 +65,11 @@ void populateLinalgBufferizePatterns(MLIRContext *context,
|
|||
Optional<TiledLinalgOp> tileLinalgOp(OpBuilder &b, LinalgOp op,
|
||||
const LinalgTilingOptions &options);
|
||||
|
||||
/// Tile and fuse the `op` with its producers. The tile and fuse proceeds in
|
||||
/// three steps
|
||||
/// - Find tile loops that are fusable with its producer tile loops (a.k.a. tile
|
||||
/// + fuse loops).
|
||||
/// - Tile just these loops of the consumer (root operation) and fuse with
|
||||
/// the producer.
|
||||
/// - Tile again the tiled consumer operation produced above to do rest of
|
||||
/// the tiling specified by the `tilingOptions`.
|
||||
/// Fuse a sequence of linalg operations (`ops`) using tile-and-fuse. This
|
||||
/// proceeds as follows:
|
||||
/// - Find outer parallel loops in these ops that can be fused.
|
||||
/// - Tile fusable outer parallel loops of the last operation in the sequence.
|
||||
/// - Fuse the remaining operations with the tiled operation
|
||||
///
|
||||
/// For example, consider the sequence of matmul below
|
||||
///
|
||||
|
@ -107,36 +96,39 @@ Optional<TiledLinalgOp> tileLinalgOp(OpBuilder &b, LinalgOp op,
|
|||
/// : memref<256x32xf32> to memref<16x32xf32, #map0>
|
||||
/// %3 = subview %arg1[0, 0] [32, 32] [1, 1]
|
||||
/// : memref<32x32xf32> to memref<32x32xf32, #map1>
|
||||
/// %4 = subview %arg3[0, 0] [32, 32] [1, 1]
|
||||
/// : memref<32x32xf32> to memref<32x32xf32, #map1>
|
||||
/// linalg.matmul
|
||||
/// ins(%2, %3 : memref<16x32xf32, #map0>, memref<32x32xf32, #map1>)
|
||||
/// outs(%0 : memref<16x32xf32, #map0>)
|
||||
/// scf.parallel (%arg6) = (%c0) to (%c32) step (%c8) {
|
||||
/// scf.for %arg7 = %c0 to %c32 step %c4 {
|
||||
/// %4 = subview %0[0, %arg7] [16, 4] [1, 1]
|
||||
/// : memref<16x32xf32, #map0> to memref<16x4xf32, #map0>
|
||||
/// %5 = subview %arg3[%arg7, %arg6] [4, 8] [1, 1]
|
||||
/// : memref<32x32xf32> to memref<4x8xf32, #map0>
|
||||
/// %6 = subview %1[0, %arg6] [16, 8] [1, 1]
|
||||
/// : memref<16x32xf32, #map0> to memref<16x8xf32, #map0>
|
||||
/// linalg.matmul
|
||||
/// ins(%4, %5 : memref<16x4xf32, #map0>, memref<4x8xf32, #map0>)
|
||||
/// outs(%6 : memref<16x8xf32, #map0>)
|
||||
/// }
|
||||
/// scf.yield
|
||||
/// }
|
||||
/// scf.yield
|
||||
/// linalg.matmul
|
||||
/// ins(%0, %4 : memref<16x4xf32, #map0>, memref<4x8xf32, #map0>)
|
||||
/// outs(%1 : memref<16x8xf32, #map0>)
|
||||
/// }
|
||||
///
|
||||
/// The following tiling options are handled differently in tile+fuse (compared
|
||||
/// to tile only)
|
||||
/// `tilingOptions` are used to tile the corresponding operation in `ops` (the
|
||||
/// size of the former should be same as size of the latter). Based on how
|
||||
/// tile+fuse is implemented, the fused loops are generated based on the last
|
||||
/// operation in the sequence. For example, the tile sizes for the fused loops
|
||||
/// is obtained from `tilingOptions.back()`. The following tiling options are
|
||||
/// handled differently in tile+fuse (compared to tile only)
|
||||
/// - Interchange of the tiling loops is not supported right now.
|
||||
/// - Distribution is only done for the tile+fuse loops. The tiled loops
|
||||
/// generated by the second tiling is not distributed.
|
||||
/// - Only the fused loops are distributed.
|
||||
struct TiledAndFusedLinalgOps {
|
||||
/// Operation obtained by tiling the last operation in sequence of `ops`
|
||||
/// passed to `tileAndFuseLinalgOps`.
|
||||
LinalgOp op;
|
||||
/// The dimension of the loops that are fused.
|
||||
std::set<unsigned> fusedLoopDims;
|
||||
/// The generated fused operations (created within the fused loops).
|
||||
SmallVector<LinalgOp, 1> fusedProducers;
|
||||
/// The fused loop generated.
|
||||
SmallVector<Operation *, 4> fusedLoops;
|
||||
};
|
||||
Optional<TiledAndFusedLinalgOps>
|
||||
tileAndFuseLinalgOps(PatternRewriter &rewriter, LinalgOp op,
|
||||
tileAndFuseLinalgOps(OpBuilder &builder, ArrayRef<LinalgOp> ops,
|
||||
const LinalgDependenceGraph &dependenceGraph,
|
||||
const LinalgTilingOptions &tilingOptions,
|
||||
const LinalgFusionOptions &fusionOptions);
|
||||
const LinalgTilingOptions &tilingOptions);
|
||||
|
||||
/// Interchanges the `iterator_types` and `iterator_maps` dimensions of `op`.
|
||||
/// This is an in-place transformation controlled by `interchangeVector`.
|
||||
|
|
|
@ -162,13 +162,24 @@ struct ShapeDimension {
|
|||
// guarantees at least one such dimension is found. If multiple candidates exist
|
||||
// they must agree by construction (i.e. have the same size) and we just return
|
||||
// the first one.
|
||||
static ShapeDimension getShapeDefiningLoopRange(LinalgOp op,
|
||||
unsigned loopDepth) {
|
||||
static ShapeDimension
|
||||
getShapeDefiningLoopRange(LinalgOp op, unsigned loopDepth,
|
||||
bool fromSubViewOpOnly = false) {
|
||||
auto maps = op.indexing_maps();
|
||||
// Iterate over the inputs and outputs in order.
|
||||
// Extract the subranges from the linearized ranges.
|
||||
SmallVector<Value, 8> ios(op.getInputsAndOutputBuffers());
|
||||
for (auto en : llvm::enumerate(ios)) {
|
||||
// The method `getRangeFromOperandShape` requires using SubViewOp or
|
||||
// SubTensorOps. If the value isn't defined by one of those ops, continue.
|
||||
// todo: The method should be adapted to get the values from
|
||||
// `ViewInterface`. The interface needs a `getOrCreateRanges` method which
|
||||
// currently returns a `linalg.range`. The fix here is to move this op to
|
||||
// `std` dialect and add the method to `ViewInterface`.
|
||||
if (fromSubViewOpOnly &&
|
||||
!isa_and_nonnull<SubViewOp, SubTensorOp>(en.value().getDefiningOp()))
|
||||
continue;
|
||||
|
||||
unsigned idx = en.index();
|
||||
auto map = maps[idx].cast<AffineMapAttr>().getValue();
|
||||
LLVM_DEBUG(llvm::dbgs()
|
||||
|
@ -178,6 +189,9 @@ static ShapeDimension getShapeDefiningLoopRange(LinalgOp op,
|
|||
Value shape = en.value();
|
||||
SmallVector<Value, 8> shapeRanges(map.getNumResults(), nullptr);
|
||||
for (auto en2 : llvm::enumerate(map.getResults())) {
|
||||
auto dimExpr = en2.value().dyn_cast<AffineDimExpr>();
|
||||
if (!dimExpr)
|
||||
continue;
|
||||
if (loopDepth == en2.value().cast<AffineDimExpr>().getPosition()) {
|
||||
LLVM_DEBUG(llvm::dbgs() << "getShapeDefiningLoopRange loopDepth: "
|
||||
<< loopDepth << "\n");
|
||||
|
@ -190,49 +204,18 @@ static ShapeDimension getShapeDefiningLoopRange(LinalgOp op,
|
|||
llvm_unreachable("Expect to be able to extract a shape defining loop range");
|
||||
}
|
||||
|
||||
/// Fuses the producer of `producerIdx` into the loop immediately enclosing
|
||||
/// `consumer`. This is achieved by "recomputing" the `producer` at the time it
|
||||
/// is needed just before the `consumer.
|
||||
///
|
||||
/// Depending on the type of `consumer.getShapedOperand(consumerIdx)`, there are
|
||||
/// 2 cases:
|
||||
/// 1. Buffer case: `producerIdx` is the index of the buffer in
|
||||
/// `producer.getOutputBuffers()`.
|
||||
/// 2. Tensor case: `producerIdx` is the index of the tensor in
|
||||
/// `producer.getResults()`.
|
||||
static LinalgOp fuse(OpBuilder &b, LinalgOp producer, unsigned producerIdx,
|
||||
LinalgOp consumer, unsigned consumerIdx) {
|
||||
Operation *shapeProducingOp =
|
||||
consumer.getShapedOperand(consumerIdx).getDefiningOp();
|
||||
assert((isa<SubViewOp>(shapeProducingOp) ||
|
||||
isa<SubTensorOp>(shapeProducingOp)) &&
|
||||
"SubviewOp or SubTensorOp expected");
|
||||
|
||||
// loopToOperandRangesMaps are permutations-only by construction:
|
||||
// we can always identify a data dimension with a (at least one) loop
|
||||
// dimension.
|
||||
// TODO: extend this with range inference.
|
||||
AffineMap producerMap = producer.getOutputIndexingMap(producerIdx);
|
||||
LLVM_DEBUG(llvm::dbgs() << "Producer Idx: " << producerIdx
|
||||
<< ", producer map: " << producerMap << "\n");
|
||||
/// Fuse the producer by cloning the `producer`. The `fusedLoopsAndRanges`
|
||||
/// provides the loop range information for the fused loops. The rest are
|
||||
/// obtained from the producer itself, since they are not tiled + fused.
|
||||
static LinalgOp fuse(OpBuilder &b, LinalgOp producer,
|
||||
const DenseMap<unsigned, Range> &fusedLoopsAndRanges) {
|
||||
|
||||
unsigned nPar = producer.getNumParallelLoops();
|
||||
unsigned nRed = producer.getNumReductionLoops();
|
||||
unsigned nWin = producer.getNumWindowLoops();
|
||||
SmallVector<Range, 8> loopRanges(nPar + nRed + nWin);
|
||||
|
||||
// Iterate over dimensions identified by the producer map for `producerIdx`.
|
||||
// This defines a subset of the loop ranges that we need to complete later.
|
||||
auto loc = consumer.getLoc();
|
||||
for (auto en : llvm::enumerate(producerMap.getResults())) {
|
||||
unsigned posInProducerLoop = en.value().cast<AffineDimExpr>().getPosition();
|
||||
loopRanges[posInProducerLoop] =
|
||||
isa<SubViewOp>(shapeProducingOp)
|
||||
? cast<SubViewOp>(shapeProducingOp)
|
||||
.getOrCreateRanges(b, loc)[en.index()]
|
||||
: cast<SubTensorOp>(shapeProducingOp)
|
||||
.getOrCreateRanges(b, loc)[en.index()];
|
||||
}
|
||||
for (auto fusedLoops : fusedLoopsAndRanges)
|
||||
loopRanges[fusedLoops.first] = fusedLoops.second;
|
||||
|
||||
// Iterate over all dimensions. For the dimensions not identified by the
|
||||
// producer map for `producerIdx`, we need to explicitly compute the shape
|
||||
|
@ -250,7 +233,45 @@ static LinalgOp fuse(OpBuilder &b, LinalgOp producer, unsigned producerIdx,
|
|||
}
|
||||
}
|
||||
|
||||
return cloneWithLoopRanges(b, loc, producer, loopRanges);
|
||||
return cloneWithLoopRanges(b, producer.getLoc(), producer, loopRanges);
|
||||
}
|
||||
|
||||
/// Get the loop range for a dimension `dim` based on the `shapedOperand`. It is
|
||||
/// expected to be defined by a subview op or a subtensor op.
|
||||
static Range getRangeFromOperandShape(OpBuilder &b, Location loc,
|
||||
Value shapedOperand, unsigned dim) {
|
||||
Operation *shapeProducingOp = shapedOperand.getDefiningOp();
|
||||
if (auto subViewOp = dyn_cast<SubViewOp>(shapeProducingOp))
|
||||
return subViewOp.getOrCreateRanges(b, loc)[dim];
|
||||
if (auto subTensorOp = dyn_cast<SubTensorOp>(shapeProducingOp))
|
||||
return subTensorOp.getOrCreateRanges(b, loc)[dim];
|
||||
llvm_unreachable("SubviewOp or SubTensorOp expected");
|
||||
}
|
||||
|
||||
/// Fuses the producer of `producerIdx` into the loop immediately enclosing
|
||||
/// `consumer`. This is achieved by "recomputing" the `producer` at the time it
|
||||
/// is needed just before the `consumer.
|
||||
///
|
||||
/// Depending on the type of `consumer.getShapedOperand(consumerIdx)`, there are
|
||||
/// 2 cases:
|
||||
/// 1. Buffer case: `producerIdx` is the index of the buffer in
|
||||
/// `producer.getOutputBuffers()`.
|
||||
/// 2. Tensor case: `producerIdx` is the index of the tensor in
|
||||
/// `producer.getResults()`.
|
||||
static LinalgOp fuse(OpBuilder &b, LinalgOp producer, unsigned producerIdx,
|
||||
LinalgOp consumer, unsigned consumerIdx) {
|
||||
AffineMap producerMap = producer.getOutputIndexingMap(producerIdx);
|
||||
LLVM_DEBUG(llvm::dbgs() << "Producer Idx: " << producerIdx
|
||||
<< ", producer map: " << producerMap << "\n");
|
||||
DenseMap<unsigned, Range> fusedLoopsAndRanges;
|
||||
Location loc = consumer.getLoc();
|
||||
Value shapedOperand = consumer.getShapedOperand(consumerIdx);
|
||||
for (auto en : llvm::enumerate(producerMap.getResults())) {
|
||||
unsigned posInProducerLoop = en.value().cast<AffineDimExpr>().getPosition();
|
||||
fusedLoopsAndRanges[posInProducerLoop] =
|
||||
getRangeFromOperandShape(b, loc, shapedOperand, en.index());
|
||||
}
|
||||
return fuse(b, producer, fusedLoopsAndRanges);
|
||||
}
|
||||
|
||||
// Encode structural fusion safety preconditions.
|
||||
|
@ -525,6 +546,69 @@ using FusableOpDependencesTy = llvm::MapVector<
|
|||
Operation *,
|
||||
SmallVector<LinalgDependenceGraph::LinalgDependenceGraphElem, 1>>;
|
||||
|
||||
/// Returns the mapping from iterations in the consumer that write to the same
|
||||
/// location as the iterations in the producer. To do so use
|
||||
/// - indexing map of the fused view in the consumer : consumerIndexMap
|
||||
/// - indexing map of the fused view in the producer : producerIndexMap
|
||||
/// consumerLoopToProducerLoop =
|
||||
/// inverse(producerIndexMap).compose(consumerIndexMap)
|
||||
static Optional<AffineMap> getConsumerLoopToProducerLoopMap(
|
||||
LinalgDependenceGraph::LinalgDependenceGraphElem dependence) {
|
||||
auto producer = cast<LinalgOp>(dependence.dependentOpView.op);
|
||||
AffineMap producerIndexingMap =
|
||||
producer.getIndexingMap(dependence.dependentOpView.operandIndex);
|
||||
auto consumer = cast<LinalgOp>(dependence.indexingOpView.op);
|
||||
AffineMap consumerIndexingMap =
|
||||
consumer.getIndexingMap(dependence.indexingOpView.operandIndex);
|
||||
|
||||
AffineMap prunedProducerIndexingMap = pruneReductionDimsFromMap(
|
||||
producer.iterator_types().getValue(), producerIndexingMap);
|
||||
if (!prunedProducerIndexingMap.isPermutation())
|
||||
return None;
|
||||
|
||||
if (consumerIndexingMap.getNumResults() !=
|
||||
prunedProducerIndexingMap.getNumResults())
|
||||
return None;
|
||||
|
||||
LLVM_DEBUG({
|
||||
llvm::dbgs() << "\t producerMap : ";
|
||||
producerIndexingMap.print(llvm::dbgs());
|
||||
llvm::dbgs() << " pruned : ";
|
||||
prunedProducerIndexingMap.print(llvm::dbgs());
|
||||
llvm::dbgs() << "\n";
|
||||
llvm::dbgs() << "\t consumerMap : ";
|
||||
consumerIndexingMap.print(llvm::dbgs());
|
||||
llvm::dbgs() << "\n";
|
||||
});
|
||||
|
||||
AffineMap invProducerIndexMap = inversePermutation(prunedProducerIndexingMap);
|
||||
if (!invProducerIndexMap)
|
||||
return None;
|
||||
|
||||
return invProducerIndexMap.compose(consumerIndexingMap);
|
||||
}
|
||||
|
||||
/// Given a projected permutation `map`, returns true if the map changes the
|
||||
/// order in which the fused loop dimension appear.
|
||||
static bool doesTransposeAccess(AffineMap map,
|
||||
const std::set<unsigned> &fusableLoops) {
|
||||
Optional<unsigned> lastFusableLoop;
|
||||
for (unsigned pos : llvm::map_range(map.getResults(), [](AffineExpr expr) {
|
||||
return expr.cast<AffineDimExpr>().getPosition();
|
||||
})) {
|
||||
if (!fusableLoops.count(pos))
|
||||
continue;
|
||||
if (!lastFusableLoop) {
|
||||
lastFusableLoop = pos;
|
||||
continue;
|
||||
}
|
||||
if (pos <= lastFusableLoop.getValue())
|
||||
return true;
|
||||
lastFusableLoop = pos;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
/// Returns the positions of the loop in `op` that can be tiled based on the
|
||||
/// operations that are to be fused with it. For example, in a
|
||||
///
|
||||
|
@ -538,13 +622,7 @@ using FusableOpDependencesTy = llvm::MapVector<
|
|||
/// 2. Of the parallel loops only some can be fused. Only those loops can be
|
||||
/// fused such where the fusable loops iteration space only touches one tile
|
||||
/// of the fused operation. This is because the producer (which is writing
|
||||
/// the fused subview) has update semantics. To compute this,
|
||||
/// a. Find the mapping from iterations in the consumer that write to the
|
||||
/// same location as the iterations in the producer. To do so use
|
||||
/// - indexing map of the fused view in the consumer : consumerIndexMap
|
||||
/// - indexing map of the fused view in the producer : producerIndexMap
|
||||
/// consumerLoopToProducerLoop =
|
||||
/// inverse(producerIndexMap).compose(consumerIndexMap)
|
||||
/// the fused subview) has update semantics.
|
||||
///
|
||||
/// Since an inverse computation is needed, we need to consider the projection
|
||||
/// of the producerIndexMap w.r.t the parallel loops. The actual fusable loops
|
||||
|
@ -582,8 +660,9 @@ using FusableOpDependencesTy = llvm::MapVector<
|
|||
/// submap with only parallel loops = affine_map<(i, j) -> (j)>
|
||||
/// Fused dimensions : j
|
||||
static std::set<unsigned>
|
||||
collectTileAndFuseLoops(LinalgOp op,
|
||||
const FusableOpDependencesTy &fusableDependences) {
|
||||
collectFusableLoops(ArrayRef<LinalgOp> ops,
|
||||
const FusableOpDependencesTy &fusableDependences) {
|
||||
assert(!ops.empty());
|
||||
auto getNumOuterParallelLoops = [](LinalgOp linalgOp) {
|
||||
return linalgOp.iterator_types()
|
||||
.getValue()
|
||||
|
@ -594,88 +673,57 @@ collectTileAndFuseLoops(LinalgOp op,
|
|||
.size();
|
||||
};
|
||||
|
||||
LLVM_DEBUG({
|
||||
llvm::dbgs() << "Op : ";
|
||||
op.getOperation()->print(llvm::dbgs(), OpPrintingFlags().useLocalScope());
|
||||
llvm::dbgs() << "\n";
|
||||
});
|
||||
|
||||
size_t numOuterParallelLoops = getNumOuterParallelLoops(op);
|
||||
for (auto dependence : fusableDependences) {
|
||||
linalg::LinalgOp producer = cast<linalg::LinalgOp>(dependence.first);
|
||||
size_t numOuterParallelLoops = getNumOuterParallelLoops(ops.back());
|
||||
for (auto op : ops.drop_back()) {
|
||||
numOuterParallelLoops =
|
||||
std::min(numOuterParallelLoops, getNumOuterParallelLoops(producer));
|
||||
std::min(numOuterParallelLoops, getNumOuterParallelLoops(op));
|
||||
}
|
||||
|
||||
std::set<unsigned> fusableLoops;
|
||||
auto range = llvm::seq<unsigned>(0, numOuterParallelLoops);
|
||||
fusableLoops.insert(range.begin(), range.end());
|
||||
for (auto dependence : fusableDependences) {
|
||||
LLVM_DEBUG({
|
||||
llvm::dbgs() << "\t fusable :";
|
||||
for (unsigned i : fusableLoops)
|
||||
llvm::dbgs() << " " << i;
|
||||
llvm::dbgs() << "\n";
|
||||
});
|
||||
linalg::LinalgOp producer = cast<linalg::LinalgOp>(dependence.first);
|
||||
|
||||
assert(!dependence.second.empty() &&
|
||||
"unexpected producer but not dependences");
|
||||
AffineMap producerIndexingMap = producer.getIndexingMap(
|
||||
dependence.second.front().dependentOpView.operandIndex);
|
||||
AffineMap prunedProducerIndexingMap = pruneReductionDimsFromMap(
|
||||
producer.iterator_types().getValue(), producerIndexingMap);
|
||||
if (!prunedProducerIndexingMap.isPermutation())
|
||||
return {};
|
||||
for (auto op : reverse(ops)) {
|
||||
for (auto dependence : fusableDependences.lookup(op)) {
|
||||
LLVM_DEBUG({
|
||||
llvm::dbgs() << "\t fusable :";
|
||||
for (unsigned i : fusableLoops)
|
||||
llvm::dbgs() << " " << i;
|
||||
llvm::dbgs() << "\n";
|
||||
});
|
||||
|
||||
AffineMap consumerIndexingMap = op.getIndexingMap(
|
||||
dependence.second.front().indexingOpView.operandIndex);
|
||||
if (consumerIndexingMap.getNumResults() !=
|
||||
prunedProducerIndexingMap.getNumResults())
|
||||
return {};
|
||||
Optional<AffineMap> consumerLoopToProducerLoop =
|
||||
getConsumerLoopToProducerLoopMap(dependence);
|
||||
if (!consumerLoopToProducerLoop) {
|
||||
op.emitRemark("failed to get map from consumer loop to producer loop");
|
||||
return {};
|
||||
}
|
||||
// todo: This condition is only an implementation limitation. When fusing
|
||||
// the operation, if the accesses in the producer/consumer are transposes
|
||||
// of each other, the loop bounds for the tiled producer can be
|
||||
// manipulated accordingly. This requires some additional bookkeeping in
|
||||
// the implementation of tile+fuse that is deferred to later.
|
||||
if (doesTransposeAccess(*consumerLoopToProducerLoop, fusableLoops)) {
|
||||
op.emitRemark("unhandled fusion when fusion requires permutation");
|
||||
return {};
|
||||
}
|
||||
|
||||
LLVM_DEBUG({
|
||||
llvm::dbgs() << "\t producerMap : ";
|
||||
producerIndexingMap.print(llvm::dbgs());
|
||||
llvm::dbgs() << " pruned : ";
|
||||
prunedProducerIndexingMap.print(llvm::dbgs());
|
||||
llvm::dbgs() << "\n";
|
||||
llvm::dbgs() << "\t consumerMap : ";
|
||||
consumerIndexingMap.print(llvm::dbgs());
|
||||
llvm::dbgs() << "\n";
|
||||
});
|
||||
|
||||
AffineMap invProducerIndexMap =
|
||||
inversePermutation(prunedProducerIndexingMap);
|
||||
if (!invProducerIndexMap)
|
||||
return {};
|
||||
|
||||
AffineMap consumerLoopToProducerLoop =
|
||||
invProducerIndexMap.compose(consumerIndexingMap);
|
||||
|
||||
LLVM_DEBUG({
|
||||
llvm::dbgs() << "\t consumerLoopToProducerLoop : ";
|
||||
consumerLoopToProducerLoop.print(llvm::dbgs());
|
||||
});
|
||||
|
||||
std::set<unsigned> candidates;
|
||||
for (AffineExpr expr : consumerLoopToProducerLoop.getResults()) {
|
||||
AffineDimExpr dimExpr = expr.dyn_cast<AffineDimExpr>();
|
||||
if (!dimExpr)
|
||||
continue;
|
||||
unsigned position = dimExpr.getPosition();
|
||||
if (fusableLoops.count(position))
|
||||
candidates.insert(position);
|
||||
std::set<unsigned> candidates;
|
||||
for (AffineExpr expr : consumerLoopToProducerLoop->getResults()) {
|
||||
unsigned position = expr.cast<AffineDimExpr>().getPosition();
|
||||
if (fusableLoops.count(position))
|
||||
candidates.insert(position);
|
||||
}
|
||||
LLVM_DEBUG({
|
||||
llvm::dbgs() << "\t candidates :";
|
||||
for (unsigned i : candidates)
|
||||
llvm::dbgs() << " " << i;
|
||||
llvm::dbgs() << "\n";
|
||||
});
|
||||
if (candidates.empty())
|
||||
return {};
|
||||
std::swap(candidates, fusableLoops);
|
||||
}
|
||||
LLVM_DEBUG({
|
||||
llvm::dbgs() << "\t candidates :";
|
||||
for (unsigned i : candidates)
|
||||
llvm::dbgs() << " " << i;
|
||||
llvm::dbgs() << "\n";
|
||||
});
|
||||
if (candidates.empty())
|
||||
return {};
|
||||
std::swap(candidates, fusableLoops);
|
||||
}
|
||||
|
||||
return fusableLoops;
|
||||
|
@ -683,60 +731,69 @@ collectTileAndFuseLoops(LinalgOp op,
|
|||
|
||||
/// Find all dependences that are to be fusable.
|
||||
static FusableOpDependencesTy
|
||||
findAllFusableDependences(LinalgOp op,
|
||||
const LinalgDependenceGraph &dependenceGraph,
|
||||
const LinalgFusionOptions &fusionOptions) {
|
||||
findAllFusableDependences(ArrayRef<LinalgOp> ops,
|
||||
const LinalgDependenceGraph &dependenceGraph) {
|
||||
FusableOpDependencesTy fusableDependences;
|
||||
// TODO: Currently fusion would not be legal if the fusable dependence is to
|
||||
// the same producer but different indexing map in the consumer. Fix this, but
|
||||
// in the meanwhile disallow such a fusion.
|
||||
DenseMap<Operation *, AffineMap> fusedProducerIndexingMap;
|
||||
for (auto operandIndex : fusionOptions.indicesToFuse) {
|
||||
auto fusableDependence =
|
||||
findFusableProducer(op, operandIndex, dependenceGraph);
|
||||
if (!fusableDependence)
|
||||
return FusableOpDependencesTy{};
|
||||
LinalgOp producerOp = cast<LinalgOp>(fusableDependence->dependentOpView.op);
|
||||
// Do not fuse dependences that are to operations not in the same basic
|
||||
// block. This avoid moving fused operations across loops that might
|
||||
// themselves carry dependency making the fusion illegal.
|
||||
if (producerOp.getOperation()->getBlock() !=
|
||||
op.getOperation()->getBlock()) {
|
||||
op.emitRemark("unhandled fusion of ops in different basic blocks");
|
||||
return FusableOpDependencesTy{};
|
||||
}
|
||||
// Make sure that the indexing map of the view used for fusion in the
|
||||
// producer is a projected permutation.
|
||||
unsigned producerIdx = fusableDependence->dependentOpView.operandIndex;
|
||||
AffineMap producerMap = producerOp.getIndexingMap(producerIdx);
|
||||
if (!producerMap.isProjectedPermutation()) {
|
||||
op.emitRemark("unhandled non permutation indexing map for fused view in "
|
||||
"producer for operand at index ")
|
||||
<< operandIndex;
|
||||
return FusableOpDependencesTy{};
|
||||
}
|
||||
for (LinalgOp op : reverse(ops)) {
|
||||
for (auto operandIndex :
|
||||
llvm::seq<unsigned>(0, op.getNumInputsAndOutputBuffers())) {
|
||||
Optional<LinalgDependenceGraph::LinalgDependenceGraphElem>
|
||||
fusableDependence =
|
||||
findFusableProducer(op, operandIndex, dependenceGraph);
|
||||
if (!fusableDependence)
|
||||
continue;
|
||||
LinalgOp producerOp =
|
||||
cast<LinalgOp>(fusableDependence->dependentOpView.op);
|
||||
// Do not fuse dependences that are to operations not in the same basic
|
||||
// block. This avoid moving fused operations across loops that might
|
||||
// themselves carry dependency making the fusion illegal.
|
||||
if (producerOp.getOperation()->getBlock() !=
|
||||
op.getOperation()->getBlock()) {
|
||||
op.emitRemark("unhandled fusion of ops in different basic blocks");
|
||||
return FusableOpDependencesTy{};
|
||||
}
|
||||
// Make sure that the indexing map of the view used for fusion in the
|
||||
// producer is a projected permutation.
|
||||
unsigned producerIdx = fusableDependence->dependentOpView.operandIndex;
|
||||
AffineMap producerMap = producerOp.getIndexingMap(producerIdx);
|
||||
if (!producerMap.isProjectedPermutation()) {
|
||||
op.emitRemark(
|
||||
"unhandled non permutation indexing map for fused view in "
|
||||
"producer for operand at index ")
|
||||
<< operandIndex;
|
||||
return FusableOpDependencesTy{};
|
||||
}
|
||||
|
||||
unsigned consumerIdx = fusableDependence->indexingOpView.operandIndex;
|
||||
AffineMap consumerMap = op.getIndexingMap(consumerIdx);
|
||||
if (!consumerMap.isProjectedPermutation()) {
|
||||
op.emitRemark(
|
||||
"unhandled case where indexing map for fused view in the consumer is "
|
||||
"not a projected permutation while fusing at index ")
|
||||
<< operandIndex;
|
||||
return FusableOpDependencesTy{};
|
||||
}
|
||||
unsigned consumerIdx = fusableDependence->indexingOpView.operandIndex;
|
||||
AffineMap consumerMap = op.getIndexingMap(consumerIdx);
|
||||
if (!consumerMap.isProjectedPermutation()) {
|
||||
op.emitRemark(
|
||||
"unhandled case where indexing map for fused view in the consumer "
|
||||
"is "
|
||||
"not a projected permuration while fusing at index ")
|
||||
<< operandIndex;
|
||||
return FusableOpDependencesTy{};
|
||||
}
|
||||
|
||||
// Check if the producer is already a fusion candidate. Cannot fuse this
|
||||
// dependence if it has a different indexing map when used in the consumer.
|
||||
if (fusedProducerIndexingMap.count(producerOp.getOperation()) &&
|
||||
fusedProducerIndexingMap[producerOp.getOperation()] != consumerMap) {
|
||||
op.emitRemark("unhandled fusion to the same producer but with different "
|
||||
"indexing maps");
|
||||
return FusableOpDependencesTy{};
|
||||
}
|
||||
fusedProducerIndexingMap[producerOp.getOperation()] = consumerMap;
|
||||
// Check if the producer is already a fusion candidate. Cannot fuse this
|
||||
// dependence if it has a different indexing map when used in the
|
||||
// consumer.
|
||||
if (fusedProducerIndexingMap.count(producerOp.getOperation()) &&
|
||||
fusedProducerIndexingMap[producerOp.getOperation()] != consumerMap) {
|
||||
op.emitRemark(
|
||||
"unhandled fusion to the same producer but with different "
|
||||
"indexing maps");
|
||||
return FusableOpDependencesTy{};
|
||||
}
|
||||
fusedProducerIndexingMap[producerOp.getOperation()] = consumerMap;
|
||||
|
||||
fusableDependences[producerOp.getOperation()].push_back(*fusableDependence);
|
||||
fusableDependences[producerOp.getOperation()].push_back(
|
||||
*fusableDependence);
|
||||
}
|
||||
}
|
||||
return fusableDependences;
|
||||
}
|
||||
|
@ -747,136 +804,120 @@ static bool isZero(Value v) {
|
|||
return false;
|
||||
}
|
||||
|
||||
/// Tile the fused loops in the root operation, by setting the tile sizes for
|
||||
/// all other loops to zero (those will be tiled later).
|
||||
static Optional<TiledLinalgOp> tileRootOperation(
|
||||
OpBuilder &builder, LinalgOp op, ArrayRef<Value> tileSizeVector,
|
||||
const LinalgTilingOptions &options, const std::set<unsigned> &fusedLoops) {
|
||||
SmallVector<Value, 4> tileSizes(tileSizeVector.begin(), tileSizeVector.end());
|
||||
auto zero = std_constant_index(0);
|
||||
for (unsigned i = 0, e = tileSizes.size(); i != e; ++i)
|
||||
if (!fusedLoops.count(i))
|
||||
tileSizes[i] = zero;
|
||||
LinalgTilingOptions tileFusedLoopsOptions = options;
|
||||
tileFusedLoopsOptions.setTileSizes(tileSizes);
|
||||
return tileLinalgOp(builder, op, tileFusedLoopsOptions);
|
||||
}
|
||||
|
||||
/// Fuse the operations in `fusionCandidates` with `tiledOp`. Latter is expected
|
||||
/// to be a tiled operation such that it is valid to fuse all operations in
|
||||
/// `fusionCandidates`, i.e. move the operation within the inter-tile loops of
|
||||
/// `tiledOp`.
|
||||
static SmallVector<LinalgOp, 1>
|
||||
fuseOperations(OpBuilder &builder, LinalgOp tiledOp,
|
||||
ArrayRef<LinalgOp> fusionCandidates,
|
||||
const FusableOpDependencesTy &fusableDependences,
|
||||
const std::set<unsigned> &fusedLoops) {
|
||||
OpBuilder::InsertionGuard guard(builder);
|
||||
builder.setInsertionPoint(tiledOp);
|
||||
DenseMap<unsigned, Range> fusedLoopsAndRanges;
|
||||
for (unsigned loop : fusedLoops) {
|
||||
ShapeDimension shapeDim = getShapeDefiningLoopRange(tiledOp, loop, true);
|
||||
fusedLoopsAndRanges[loop] = getRangeFromOperandShape(
|
||||
builder, tiledOp.getLoc(), shapeDim.shape, shapeDim.dimension);
|
||||
}
|
||||
SmallVector<LinalgOp, 1> fusedOps(fusionCandidates.size());
|
||||
for (auto candidate : enumerate(llvm::reverse(fusionCandidates))) {
|
||||
LinalgOp fusedOp = fuse(builder, candidate.value(), fusedLoopsAndRanges);
|
||||
fusedOps[fusionCandidates.size() - candidate.index() - 1] = fusedOp;
|
||||
builder.setInsertionPoint(fusedOp);
|
||||
}
|
||||
return fusedOps;
|
||||
}
|
||||
|
||||
template <typename LoopType>
|
||||
static Optional<TiledAndFusedLinalgOps>
|
||||
tileAndFuseLinalgOpsImpl(PatternRewriter &rewriter, LinalgOp op,
|
||||
tileAndFuseLinalgOpsImpl(OpBuilder &builder, ArrayRef<LinalgOp> ops,
|
||||
const LinalgDependenceGraph &dependenceGraph,
|
||||
const LinalgTilingOptions &tilingOptions,
|
||||
const LinalgFusionOptions &fusionOptions) {
|
||||
assert(op.hasBufferSemantics() && "expected linalg op with buffer semantics");
|
||||
// Some of the tiling options might not be supportable with tile and fuse.
|
||||
// TODO: Support interchange with tile + fuse.
|
||||
const LinalgTilingOptions &tilingOptions) {
|
||||
if (ops.empty())
|
||||
return llvm::None;
|
||||
LinalgOp rootOp = ops.back();
|
||||
for (auto op : enumerate(ops)) {
|
||||
// TODO: Nothing in the fusion of sequence of ops is specific to
|
||||
// buffers. This check can be removed after it is tested on tensors.
|
||||
LinalgOp linalgOp = op.value();
|
||||
if (!linalgOp.hasBufferSemantics()) {
|
||||
linalgOp.emitError("tile and fuse only tested for buffer operation");
|
||||
return llvm::None;
|
||||
}
|
||||
}
|
||||
// TODO: Support interchange with tile + fuse. This might actually help do
|
||||
// better fusion.
|
||||
if (!tilingOptions.interchangeVector.empty()) {
|
||||
op.emitError("unable to handle tile and fuse with interchange");
|
||||
rootOp.emitError("unable to handle tile and fuse with interchange");
|
||||
return llvm::None;
|
||||
}
|
||||
|
||||
OpBuilder::InsertionGuard g(rewriter);
|
||||
rewriter.setInsertionPoint(op);
|
||||
ScopedContext scope(rewriter, op.getLoc());
|
||||
OpBuilder::InsertionGuard guard(builder);
|
||||
builder.setInsertionPoint(rootOp);
|
||||
ScopedContext scope(builder, rootOp.getLoc());
|
||||
|
||||
// Find all the producers.
|
||||
FusableOpDependencesTy fusableDependences =
|
||||
findAllFusableDependences(op, dependenceGraph, fusionOptions);
|
||||
findAllFusableDependences(ops, dependenceGraph);
|
||||
if (fusableDependences.empty())
|
||||
return llvm::None;
|
||||
|
||||
// Enforce the convention that "tiling by zero" skips tiling a particular
|
||||
// dimension. This convention is significantly simpler to handle instead of
|
||||
// adjusting affine maps to account for missing dimensions.
|
||||
auto nLoops = op.getNumLoops();
|
||||
SmallVector<Value, 4> tileSizeVector =
|
||||
tilingOptions.tileSizeComputationFunction(rewriter, op);
|
||||
if (tileSizeVector.size() < nLoops) {
|
||||
auto zero = std_constant_index(0);
|
||||
tileSizeVector.append(nLoops - tileSizeVector.size(), zero);
|
||||
}
|
||||
|
||||
TiledAndFusedLinalgOps ret;
|
||||
|
||||
// Find the loops that can be tiled and fused.
|
||||
std::set<unsigned> tileFuseLoops =
|
||||
collectTileAndFuseLoops(op, fusableDependences);
|
||||
ret.fusedLoopDims = collectFusableLoops(ops, fusableDependences);
|
||||
|
||||
// If there are no fusable dependences or there are no tile+fusable loops,
|
||||
// just return.
|
||||
if (tileFuseLoops.empty()) {
|
||||
if (ret.fusedLoopDims.empty()) {
|
||||
return llvm::None;
|
||||
}
|
||||
|
||||
// Get the tile sizes for the first and second tiling steps. For the first
|
||||
// step the tile sizes are set to zero for the loops that aren't
|
||||
// fused. Similarly for the second step, the tile sizes are set to zero for
|
||||
// the loops that are fused. For example, if for the following input
|
||||
//
|
||||
// ```
|
||||
// linalg.add ins(%a, %b) outs(%c)
|
||||
// linalg.matmul ins(%d, %c) outs(%e)
|
||||
// ```
|
||||
//
|
||||
// if the tile sizes of the `{i, j, k}` loops were given as `{t_i, t_j, t_k}`
|
||||
// respectively, and since only `j` can be tiled and fused. The tile sizes
|
||||
// would be `{0, t_j, 0}` for the first tiling that tiles just the fusable
|
||||
// loops. The second tiling would use tile sizes of `{t_i, 0, t_k}` to tile
|
||||
// the tiled matmul generated by the first tiling step.
|
||||
SmallVector<Value, 4> tileAndFuseSizes, tileSizes;
|
||||
for (auto tileSize : enumerate(tileSizeVector)) {
|
||||
auto zero = std_constant_index(0);
|
||||
if (tileFuseLoops.count(tileSize.index())) {
|
||||
tileAndFuseSizes.push_back(tileSize.value());
|
||||
tileSizes.push_back(zero);
|
||||
} else {
|
||||
tileSizes.push_back(tileSize.value());
|
||||
tileAndFuseSizes.push_back(zero);
|
||||
}
|
||||
}
|
||||
|
||||
// Tile for the loops that can be fused.
|
||||
LinalgTilingOptions firstTilingOptions = tilingOptions;
|
||||
firstTilingOptions.setTileSizes(tileAndFuseSizes);
|
||||
Optional<TiledLinalgOp> firstTiledOp =
|
||||
tileLinalgOp(rewriter, op, firstTilingOptions);
|
||||
if (!firstTiledOp)
|
||||
// Tile the fused loops in the last operation in the list.
|
||||
SmallVector<Value, 4> tileSizeVector =
|
||||
tilingOptions.tileSizeComputationFunction(builder, rootOp);
|
||||
Optional<TiledLinalgOp> tiledRootOp = tileRootOperation(
|
||||
builder, rootOp, tileSizeVector, tilingOptions, ret.fusedLoopDims);
|
||||
if (!tiledRootOp) {
|
||||
rootOp.emitError("failed to tile the fused loops");
|
||||
return llvm::None;
|
||||
ret.op = firstTiledOp->op;
|
||||
ret.fusedLoops.assign(firstTiledOp->loops.begin(), firstTiledOp->loops.end());
|
||||
|
||||
rewriter.setInsertionPoint(ret.op);
|
||||
// Fuse the operands.
|
||||
for (auto dependence : fusableDependences) {
|
||||
LinalgOp producerOp = cast<LinalgOp>(dependence.first);
|
||||
unsigned producerIdx =
|
||||
dependence.second.front().dependentOpView.operandIndex;
|
||||
unsigned consumerIdx =
|
||||
dependence.second.front().indexingOpView.operandIndex;
|
||||
LinalgOp fusedOp = fuse(rewriter, producerOp,
|
||||
producerOp.getOutputIndex(producerIdx).getValue(),
|
||||
ret.op, consumerIdx);
|
||||
ret.fusedProducers.push_back(fusedOp);
|
||||
ret.originalProducers.push_back(producerOp);
|
||||
}
|
||||
|
||||
if (!llvm::all_of(tileSizes, isZero)) {
|
||||
// Tile the remaining loops of the root operation.
|
||||
LinalgTilingOptions secondTilingOptions = tilingOptions;
|
||||
// The distribution is done only for the tile+fused loops.
|
||||
secondTilingOptions.distribution = llvm::None;
|
||||
secondTilingOptions.setTileSizes(tileSizes);
|
||||
Optional<TiledLinalgOp> secondTiledOp =
|
||||
tileLinalgOp(rewriter, ret.op, secondTilingOptions);
|
||||
if (!secondTiledOp)
|
||||
return llvm::None;
|
||||
ret.unfusedLoops.assign(secondTiledOp->loops.begin(),
|
||||
secondTiledOp->loops.end());
|
||||
rewriter.eraseOp(ret.op);
|
||||
ret.op = secondTiledOp->op;
|
||||
}
|
||||
ret.op = tiledRootOp->op;
|
||||
ret.fusedLoops.assign(tiledRootOp->loops.begin(), tiledRootOp->loops.end());
|
||||
|
||||
// Fuse the other operations into the fused inter-tile loops produced above.
|
||||
ret.fusedProducers = fuseOperations(builder, ret.op, ops.drop_back(),
|
||||
fusableDependences, ret.fusedLoopDims);
|
||||
return ret;
|
||||
}
|
||||
|
||||
Optional<TiledAndFusedLinalgOps>
|
||||
mlir::linalg::tileAndFuseLinalgOps(PatternRewriter &rewriter, LinalgOp op,
|
||||
mlir::linalg::tileAndFuseLinalgOps(OpBuilder &builder, ArrayRef<LinalgOp> ops,
|
||||
const LinalgDependenceGraph &dependenceGraph,
|
||||
const LinalgTilingOptions &tilingOptions,
|
||||
const LinalgFusionOptions &fusionOptions) {
|
||||
const LinalgTilingOptions &tilingOptions) {
|
||||
switch (tilingOptions.loopType) {
|
||||
case LinalgTilingLoopType::Loops:
|
||||
return tileAndFuseLinalgOpsImpl<scf::ForOp>(rewriter, op, dependenceGraph,
|
||||
tilingOptions, fusionOptions);
|
||||
return tileAndFuseLinalgOpsImpl<scf::ForOp>(builder, ops, dependenceGraph,
|
||||
tilingOptions);
|
||||
case LinalgTilingLoopType::ParallelLoops:
|
||||
return tileAndFuseLinalgOpsImpl<scf::ParallelOp>(
|
||||
rewriter, op, dependenceGraph, tilingOptions, fusionOptions);
|
||||
builder, ops, dependenceGraph, tilingOptions);
|
||||
default:;
|
||||
}
|
||||
return llvm::None;
|
||||
|
|
|
@ -165,17 +165,69 @@ LogicalResult mlir::linalg::LinalgBaseTileAndFusePattern::matchAndRewrite(
|
|||
if (!linalgOp.hasBufferSemantics())
|
||||
return failure();
|
||||
|
||||
DenseSet<Operation *> producers;
|
||||
producers.insert(linalgOp);
|
||||
for (auto dependence : dependenceGraph.getDependentOperations(linalgOp)) {
|
||||
if (!fusionOptions.indicesToFuse.count(
|
||||
dependence.indexingOpView.operandIndex))
|
||||
continue;
|
||||
if (isa<LinalgOp>(dependence.dependentOpView.op))
|
||||
producers.insert(dependence.dependentOpView.op);
|
||||
}
|
||||
|
||||
SmallVector<LinalgOp, 1> fusionOps;
|
||||
for (auto it = op->getBlock()->begin(), ie = Block::iterator(op); it != ie;
|
||||
++it) {
|
||||
auto producerLinalgOp = dyn_cast<LinalgOp>(&(*it));
|
||||
if (producerLinalgOp && producers.count(producerLinalgOp))
|
||||
fusionOps.push_back(producerLinalgOp);
|
||||
}
|
||||
fusionOps.push_back(linalgOp);
|
||||
|
||||
SmallVector<Value, 4> tileSizes =
|
||||
tilingOptions.tileSizeComputationFunction(rewriter, op);
|
||||
LinalgTilingOptions instanceTilingOptions = tilingOptions;
|
||||
instanceTilingOptions.setTileSizes(tileSizes);
|
||||
Optional<TiledAndFusedLinalgOps> tiledAndFusedOps = tileAndFuseLinalgOps(
|
||||
rewriter, op, dependenceGraph, tilingOptions, fusionOptions);
|
||||
rewriter, fusionOps, dependenceGraph, instanceTilingOptions);
|
||||
if (!tiledAndFusedOps)
|
||||
return failure();
|
||||
|
||||
// Tile the unfused loops;
|
||||
SmallVector<Value, 4> unfusedLoopTileSizes;
|
||||
Value zero = rewriter.create<ConstantIndexOp>(op->getLoc(), 0);
|
||||
for (auto tileSize : enumerate(tileSizes)) {
|
||||
if (tiledAndFusedOps->fusedLoopDims.count(tileSize.index()))
|
||||
unfusedLoopTileSizes.push_back(zero);
|
||||
else
|
||||
unfusedLoopTileSizes.push_back(tileSize.value());
|
||||
}
|
||||
// Tile the loop only if there is a non-zero tile size.
|
||||
if (unfusedLoopTileSizes.size() > linalgOp.getNumLoops())
|
||||
unfusedLoopTileSizes.resize(linalgOp.getNumLoops());
|
||||
if (llvm::any_of(unfusedLoopTileSizes, [](Value val) {
|
||||
if (auto cst = val.getDefiningOp<ConstantIndexOp>())
|
||||
return cst.getValue() != 0;
|
||||
return true;
|
||||
})) {
|
||||
LinalgTilingOptions unfusedTilingOptions = tilingOptions;
|
||||
unfusedTilingOptions.setTileSizes(unfusedLoopTileSizes);
|
||||
Optional<TiledLinalgOp> unfusedTiledOp =
|
||||
tileLinalgOp(rewriter, tiledAndFusedOps->op, unfusedTilingOptions);
|
||||
if (!unfusedTiledOp)
|
||||
return failure();
|
||||
rewriter.eraseOp(tiledAndFusedOps->op);
|
||||
tiledAndFusedOps->op = unfusedTiledOp->op;
|
||||
}
|
||||
|
||||
marker.replaceLinalgMarker(rewriter, tiledAndFusedOps->op.getOperation());
|
||||
for (auto fusedOp : tiledAndFusedOps->fusedProducers) {
|
||||
fusedOpMarker.replaceLinalgMarker(rewriter, fusedOp.getOperation());
|
||||
}
|
||||
for (auto origProducerOp : tiledAndFusedOps->originalProducers)
|
||||
for (auto origProducerOp : ArrayRef<LinalgOp>(fusionOps).drop_back()) {
|
||||
originalOpMarker.replaceLinalgMarker(rewriter,
|
||||
origProducerOp.getOperation());
|
||||
}
|
||||
rewriter.updateRootInPlace(
|
||||
op, [&]() { originalOpMarker.replaceLinalgMarker(rewriter, op); });
|
||||
return success();
|
||||
|
|
|
@ -47,7 +47,9 @@ module {
|
|||
// CHECK: %[[TILE_N_2:.+]] = affine.min #[[MAP2]](%[[IV1]])[%[[N_2]]]
|
||||
// CHECK: %[[SV3:.+]] = subview %[[ARG2]][%[[IV0]], %[[IV1]]]
|
||||
// CHECK-SAME: [%[[TILE_M_2]], %[[TILE_N_2]]]
|
||||
// CHECK: linalg.fill(%[[SV3]], %[[CST]])
|
||||
// CHECK: %[[SV3_2:.+]] = subview %[[ARG2]][%[[IV0]], %[[IV1]]]
|
||||
// CHECK-SAME: [%[[TILE_M]], %[[TILE_N]]]
|
||||
// CHECK: linalg.fill(%[[SV3_2]], %[[CST]])
|
||||
// CHECK-SAME: __internal_linalg_transform__ = "after_basic_fusion_producer"
|
||||
// CHECK: scf.for %[[IV2:.+]] = %[[C0]] to %[[K]] step %[[C16]] {
|
||||
// CHECK: %[[TILE_K:.+]] = affine.min #[[MAP3]](%[[IV2]])[%[[K]]]
|
||||
|
@ -109,9 +111,12 @@ module {
|
|||
// CHECK: %[[TILE_N_2:.+]] = affine.min #[[MAP0]](%[[IV0]])[%[[N_2]]]
|
||||
// CHECK: %[[SV2:.+]] = subview %[[ARG3]][0, %[[IV0]]]
|
||||
// CHECK-SAME: [%[[M]], %[[TILE_N_2]]]
|
||||
// CHECK: %[[K_2:.+]] = dim %[[ARG1]], %[[C0]]
|
||||
// CHECK: %[[SV3:.+]] = subview %[[ARG1]][0, %[[IV0]]]
|
||||
// CHECK-SAME: [%[[K]], %[[TILE_N]]]
|
||||
// CHECK: linalg.copy(%[[SV3]], %[[SV1]])
|
||||
// CHECK-SAME: [%[[K_2]], %[[TILE_N]]]
|
||||
// CHECK: %[[SV3_2:.+]] = subview %[[ARG2]][0, %[[IV0]]]
|
||||
// CHECK-SAME: [%[[K_2]], %[[TILE_N]]]
|
||||
// CHECK: linalg.copy(%[[SV3]], %[[SV3_2]])
|
||||
// CHECK-SAME: __internal_linalg_transform__ = "after_rhs_fusion_producer"
|
||||
// CHECK-NOT: linalg.fill
|
||||
// CHECK-DAG: %[[M_2:.+]] = dim %[[ARG0]], %[[C0]]
|
||||
|
@ -186,11 +191,16 @@ module {
|
|||
// CHECK: %[[N:.+]] = dim %[[ARG3]], %[[C1]]
|
||||
// CHECK: %[[SV2:.+]] = subview %[[ARG3]][%[[IV0]], 0]
|
||||
// CHECK-SAME: [%[[TILE_M_2]], %[[N]]]
|
||||
// CHECK: %[[SV2_2:.+]] = subview %[[ARG3]][%[[IV0]], 0]
|
||||
// CHECK-SAME: [%[[TILE_M]], %[[N]]]
|
||||
// CHECK: %[[K_2:.+]] = dim %[[ARG0]], %[[C1]]
|
||||
// CHECK: %[[SV3:.+]] = subview %[[ARG0]][%[[IV0]], 0]
|
||||
// CHECK-SAME: [%[[TILE_M]], %[[K]]]
|
||||
// CHECK: linalg.copy(%[[SV3]], %[[SV1]])
|
||||
// CHECK-SAME: [%[[TILE_M]], %[[K_2]]]
|
||||
// CHECK: %[[SV3_2:.+]] = subview %[[ARG1]][%[[IV0]], 0]
|
||||
// CHECK-SAME: [%[[TILE_M]], %[[K_2]]]
|
||||
// CHECK: linalg.copy(%[[SV3]], %[[SV3_2]])
|
||||
// CHECK-SAME: __internal_linalg_transform__ = "after_two_operand_fusion_producer"
|
||||
// CHECK: linalg.fill(%[[SV2]], %[[CST]])
|
||||
// CHECK: linalg.fill(%[[SV2_2]], %[[CST]])
|
||||
// CHECK-SAME: __internal_linalg_transform__ = "after_two_operand_fusion_producer"
|
||||
// CHECK-DAG: %[[N_2:.+]] = dim %[[ARG2]], %[[C1]]
|
||||
// CHECK: scf.parallel (%[[IV1:.+]]) =
|
||||
|
@ -261,15 +271,18 @@ module {
|
|||
// CHECK: %[[N:.+]] = dim %[[ARG4]], %[[C1]]
|
||||
// CHECK: %[[SV2:.+]] = subview %[[ARG4]][%[[IV0]], 0]
|
||||
// CHECK-SAME: [%[[TILE_M_2]], %[[N]]]
|
||||
// CHECK: %[[K2_2:.+]] = dim %[[ARG1]], %[[C1]]
|
||||
// CHECK: %[[K1:.+]] = dim %[[ARG0]], %[[C1]]
|
||||
// CHECK: %[[SV3:.+]] = subview %[[ARG0]][%[[IV0]], 0]
|
||||
// CHECK-SAME: [%[[TILE_M]], %[[K1]]]
|
||||
// CHECK: %[[SV4:.+]] = subview %[[ARG1]][0, 0] [%[[K1]], %[[K2]]]
|
||||
// CHECK: %[[SV4:.+]] = subview %[[ARG1]][0, 0] [%[[K1]], %[[K2_2]]]
|
||||
// CHECK: %[[SV1_2:.+]] = subview %[[ARG2]][%[[IV0]], 0]
|
||||
// CHECK-SAME: [%[[TILE_M]], %[[K2_2]]]
|
||||
// CHECK: linalg.matmul
|
||||
// CHECK-SAME: __internal_linalg_transform__ = "after_lhs_fusion_producer"
|
||||
// CHECK-SAME: ins(%[[SV3]], %[[SV4]]
|
||||
// CHECK-SAME: : memref<?x?xf32, #[[MAP1]]>, memref<?x?xf32, #[[MAP1]]>)
|
||||
// CHECK-SAME: outs(%[[SV1]] : memref<?x?xf32, #[[MAP1]]>)
|
||||
// CHECK-SAME: outs(%[[SV1_2]] : memref<?x?xf32, #[[MAP1]]>)
|
||||
// CHECK-DAG: %[[N_2:.+]] = dim %[[ARG3]], %[[C1]]
|
||||
// CHECK: scf.parallel (%[[IV1:.+]]) =
|
||||
// CHECK-SAME: (%[[C0]]) to (%[[N_2]]) step (%[[C64]]) {
|
||||
|
@ -413,3 +426,30 @@ module {
|
|||
return
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
module {
|
||||
func @basic_conv_fusion(%arg0: memref<?x?x?x?xf32>, %arg1: memref<?x?x?x?xf32>,
|
||||
%arg2: memref<?x?x?x?xf32>) {
|
||||
%cst = constant 0.000000e+00 : f32
|
||||
linalg.fill(%arg2, %cst) : memref<?x?x?x?xf32>, f32
|
||||
linalg.conv(%arg0, %arg1, %arg2) {
|
||||
dilations = [1, 1], strides = [1, 1],
|
||||
__internal_linalg_transform__ = "basic_fusion"} :
|
||||
memref<?x?x?x?xf32>, memref<?x?x?x?xf32>, memref<?x?x?x?xf32>
|
||||
return
|
||||
}
|
||||
}
|
||||
// CHECK: func @basic_conv_fusion
|
||||
// CHECK: linalg.fill
|
||||
// CHECK-SAME: __internal_linalg_transform__ = "after_basic_fusion_original"
|
||||
// CHECK: scf.parallel (%{{.+}}, %{{.+}}, %{{.+}})
|
||||
// CHECK-SAME: {
|
||||
// CHECK: linalg.fill
|
||||
// CHECK-SAME: __internal_linalg_transform__ = "after_basic_fusion_producer"
|
||||
// CHECK: linalg.conv
|
||||
// CHECK-SAME: __internal_linalg_transform__ = "after_basic_fusion"
|
||||
// CHECK: }
|
||||
// CHECK: linalg.conv
|
||||
// CHECK-SAME: __internal_linalg_transform__ = "after_basic_fusion_original"
|
||||
|
|
|
@ -0,0 +1,133 @@
|
|||
// RUN: mlir-opt -pass-pipeline="func(test-linalg-tile-and-fuse{tile-sizes=16,32,64}),canonicalize,cse" -split-input-file %s | FileCheck %s
|
||||
|
||||
module {
|
||||
func @three_op_fusion(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>,
|
||||
%arg2: memref<?xf32>, %arg3 : memref<?x?xf32>) {
|
||||
%cst = constant 0.000000e+00 : f32
|
||||
%c0 = constant 0 : index
|
||||
%c1 = constant 1 : index
|
||||
%d0 = dim %arg0, %c0 : memref<?x?xf32>
|
||||
%d1 = dim %arg1, %c1 : memref<?x?xf32>
|
||||
%0 = alloc(%d0, %d1) : memref<?x?xf32>
|
||||
linalg.fill(%0, %cst) : memref<?x?xf32>, f32
|
||||
linalg.matmul ins(%arg0, %arg1 : memref<?x?xf32>, memref<?x?xf32>)
|
||||
outs(%0 : memref<?x?xf32>)
|
||||
linalg.generic
|
||||
{indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
|
||||
affine_map<(d0, d1) -> (d1)>,
|
||||
affine_map<(d0, d1) -> (d0, d1)>],
|
||||
iterator_types = ["parallel", "parallel"]}
|
||||
ins(%0, %arg2 : memref<?x?xf32>, memref<?xf32>)
|
||||
outs(%arg3 : memref<?x?xf32>) {
|
||||
^bb0(%arg4 : f32, %arg5 : f32, %arg6 : f32) :
|
||||
%5 = addf %arg4, %arg5 : f32
|
||||
linalg.yield %5 : f32
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>
|
||||
// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0)[s0] -> (d0 + s0)>
|
||||
// CHECK: func @three_op_fusion
|
||||
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: memref<?x?xf32>
|
||||
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: memref<?x?xf32>
|
||||
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: memref<?xf32>
|
||||
// CHECK-SAME: %[[ARG3:[a-zA-Z0-9_]+]]: memref<?x?xf32>
|
||||
// CHECK: %[[TEMP:.+]] = alloc(%{{.*}}, %{{.*}}) : memref<?x?xf32>
|
||||
// CHECK: scf.parallel (%[[IV0:.+]], %[[IV1:.+]]) = {{.*}} {
|
||||
// CHECK-DAG: %[[SV_TEMP:.+]] = subview %[[TEMP]][%[[IV0]], %[[IV1]]]
|
||||
// CHECK-DAG: %[[SV_ARG2:.+]] = subview %[[ARG2]][%[[IV1]]]
|
||||
// CHECK-DAG: %[[SV_ARG3:.+]] = subview %[[ARG3]][%[[IV0]], %[[IV1]]]
|
||||
// CHECK-DAG: %[[SV_ARG0:.+]] = subview %[[ARG0]][%[[IV0]], 0]
|
||||
// CHECK-DAG: %[[SV_ARG1:.+]] = subview %[[ARG1]][0, %[[IV1]]]
|
||||
// CHECK: linalg.fill(%[[SV_TEMP]], %{{.+}})
|
||||
// CHECK: linalg.matmul
|
||||
// CHECK-SAME: ins(%[[SV_ARG0]], %[[SV_ARG1]]
|
||||
// CHECK-SAME: : memref<?x?xf32, #[[MAP2]]>, memref<?x?xf32, #[[MAP2]]>)
|
||||
// CHECK-SAME: outs(%[[SV_TEMP]] : memref<?x?xf32, #[[MAP2]]>)
|
||||
// CHECK: linalg.generic
|
||||
// CHECK-SAME: ins(%[[SV_TEMP]], %[[SV_ARG2]]
|
||||
// CHECK-SAME: : memref<?x?xf32, #[[MAP2]]>, memref<?xf32, #[[MAP3]]>)
|
||||
// CHECK-SAME: outs(%[[SV_ARG3]] : memref<?x?xf32, #[[MAP2]]>)
|
||||
// CHECK: scf.yield
|
||||
// CHECK: }
|
||||
|
||||
// -----
|
||||
|
||||
module {
|
||||
func @sequence_of_matmul(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>,
|
||||
%arg2: memref<?x?xf32>, %arg3: memref<?x?xf32>,
|
||||
%arg4: memref<?x?xf32>) {
|
||||
%cst = constant 0.000000e+00 : f32
|
||||
%c0 = constant 0 : index
|
||||
%c1 = constant 1 : index
|
||||
%m = dim %arg0, %c0 : memref<?x?xf32>
|
||||
%n1 = dim %arg1, %c1 : memref<?x?xf32>
|
||||
%n2 = dim %arg2, %c1 : memref<?x?xf32>
|
||||
%n3 = dim %arg3, %c1 : memref<?x?xf32>
|
||||
%0 = alloc(%m, %n1) : memref<?x?xf32>
|
||||
%1 = alloc(%m, %n2) : memref<?x?xf32>
|
||||
linalg.fill(%0, %cst) : memref<?x?xf32>, f32
|
||||
linalg.matmul ins(%arg0, %arg1 : memref<?x?xf32>, memref<?x?xf32>)
|
||||
outs(%0 : memref<?x?xf32>)
|
||||
linalg.fill(%1, %cst) : memref<?x?xf32>, f32
|
||||
linalg.matmul ins(%0, %arg2 : memref<?x?xf32>, memref<?x?xf32>)
|
||||
outs(%1 : memref<?x?xf32>)
|
||||
linalg.fill(%arg4, %cst) : memref<?x?xf32>, f32
|
||||
linalg.matmul ins(%1, %arg3 : memref<?x?xf32>, memref<?x?xf32>)
|
||||
outs(%arg4 : memref<?x?xf32>)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0)[s0] -> (16, -d0 + s0)>
|
||||
// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>
|
||||
// CHECK: func @sequence_of_matmul
|
||||
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: memref<?x?xf32>
|
||||
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: memref<?x?xf32>
|
||||
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: memref<?x?xf32>
|
||||
// CHECK-SAME: %[[ARG3:[a-zA-Z0-9_]+]]: memref<?x?xf32>
|
||||
// CHECK-SAME: %[[ARG4:[a-zA-Z0-9_]+]]: memref<?x?xf32>
|
||||
// CHECK-DAG: %[[C0:.+]] = constant 0 : index
|
||||
// CHECK-DAG: %[[C1:.+]] = constant 1 : index
|
||||
// CHECK-DAG: %[[C16:.+]] = constant 16 : index
|
||||
// CHECK-DAG: %[[M:.+]] = dim %[[ARG0]], %[[C0]]
|
||||
// CHECK-DAG: %[[N1:.+]] = dim %[[ARG1]], %[[C1]]
|
||||
// CHECK-DAG: %[[N2:.+]] = dim %[[ARG2]], %[[C1]]
|
||||
// CHECK: %[[ALLOC1:.+]] = alloc(%[[M]], %[[N1]])
|
||||
// CHECK: %[[ALLOC2:.+]] = alloc(%[[M]], %[[N2]])
|
||||
// CHECK: scf.parallel (%[[IV0:.+]]) = (%[[C0]]) to (%[[M]])
|
||||
// CHECK-SAME: step (%[[C16]]) {
|
||||
// CHECK: %[[TILE_M:.+]] = affine.min #[[MAP0]](%[[IV0]])[%[[M]]]
|
||||
// CHECK: %[[SV_ALLOC2:.+]] = subview %[[ALLOC2]][%[[IV0]], 0]
|
||||
// CHECK-SAME: [%[[TILE_M]], %[[N2]]]
|
||||
// CHECK: %[[M_2:.+]] = dim %[[ARG4]], %[[C0]]
|
||||
// CHECK: %[[TILE_M_2:.+]] = affine.min #[[MAP0]](%[[IV0]])[%[[M_2]]]
|
||||
// CHECK: %[[N3:.+]] = dim %[[ARG4]], %[[C1]]
|
||||
// CHECK: %[[SV_ARG4:.+]] = subview %[[ARG4]][%[[IV0]], 0]
|
||||
// CHECK-SAME: [%[[TILE_M_2]], %[[N3]]]
|
||||
// CHECK: %[[SV_ARG4_2:.+]] = subview %[[ARG4]][%[[IV0]], 0]
|
||||
// CHECK-SAME: [%[[TILE_M]], %[[N3]]]
|
||||
// CHECK: %[[SV_ALLOC1:.+]] = subview %[[ALLOC1]][%[[IV0]], 0]
|
||||
// CHECK-SAME: [%[[TILE_M]], %[[N1]]]
|
||||
// CHECK: %[[SV_ARG2:.+]] = subview %[[ARG2]][0, 0] [%[[N1]], %[[N2]]]
|
||||
// CHECK: %[[N0:.+]] = dim %[[ARG0]], %[[C1]]
|
||||
// CHECK: %[[SV_ARG0:.+]] = subview %[[ARG0]][%[[IV0]], 0]
|
||||
// CHECK-SAME: [%[[TILE_M:.+]], %[[N0]]]
|
||||
// CHECK: %[[SV_ARG1:.+]] = subview %[[ARG1]][0, 0] [%[[N0]], %[[N1]]]
|
||||
// CHECK: linalg.fill(%[[SV_ALLOC1]], %{{.+}})
|
||||
// CHECK: linalg.matmul ins(%[[SV_ARG0]], %[[SV_ARG1]]
|
||||
// CHECK-SAME: : memref<?x?xf32, #[[MAP1]]>, memref<?x?xf32, #[[MAP1]]>)
|
||||
// CHECK-SAME: outs(%[[SV_ALLOC1]] : memref<?x?xf32, #[[MAP1]]>)
|
||||
// CHECK: linalg.fill(%[[SV_ALLOC2]], %{{.+}})
|
||||
// CHECK: linalg.matmul ins(%[[SV_ALLOC1]], %[[SV_ARG2]]
|
||||
// CHECK-SAME: : memref<?x?xf32, #[[MAP1]]>, memref<?x?xf32, #[[MAP1]]>)
|
||||
// CHECK-SAME: outs(%[[SV_ALLOC2]] : memref<?x?xf32, #[[MAP1]]>)
|
||||
// CHECK: linalg.fill(%[[SV_ARG4_2]], %{{.+}})
|
||||
// CHECK: linalg.matmul ins(%[[SV_ALLOC2]], %[[ARG3]]
|
||||
// CHECK-SAME: : memref<?x?xf32, #[[MAP1]]>, memref<?x?xf32>)
|
||||
// CHECK-SAME: outs(%[[SV_ARG4]] : memref<?x?xf32, #[[MAP1]]>)
|
||||
// CHECK: scf.yield
|
||||
// CHECK: }
|
||||
|
|
@ -38,7 +38,8 @@ struct TestLinalgFusionTransforms
|
|||
static void fillFusionPatterns(MLIRContext *context,
|
||||
const LinalgDependenceGraph &dependenceGraph,
|
||||
OwningRewritePatternList &patterns) {
|
||||
patterns.insert<LinalgTileAndFusePattern<MatmulOp>>(
|
||||
patterns.insert<LinalgTileAndFusePattern<MatmulOp>,
|
||||
LinalgTileAndFusePattern<ConvOp>>(
|
||||
context, dependenceGraph,
|
||||
LinalgTilingOptions()
|
||||
.setTileSizes({32, 64, 16})
|
||||
|
@ -197,6 +198,44 @@ struct TestLinalgGreedyFusion
|
|||
}
|
||||
}
|
||||
};
|
||||
|
||||
/// Pass to test tile and fuse of sequence of operations. Intended only for
|
||||
/// testing.
|
||||
struct TestLinalgTileAndFuseSequencePass
|
||||
: public PassWrapper<TestLinalgTileAndFuseSequencePass, FunctionPass> {
|
||||
TestLinalgTileAndFuseSequencePass() = default;
|
||||
TestLinalgTileAndFuseSequencePass(
|
||||
const TestLinalgTileAndFuseSequencePass &pass){};
|
||||
|
||||
ListOption<int64_t> tileSizes{
|
||||
*this, "tile-sizes", llvm::cl::desc("Tile sizes to use for ops"),
|
||||
llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated};
|
||||
|
||||
void getDependentDialects(DialectRegistry ®istry) const override {
|
||||
registry.insert<AffineDialect, linalg::LinalgDialect, scf::SCFDialect>();
|
||||
}
|
||||
|
||||
void runOnFunction() override {
|
||||
FuncOp funcOp = getOperation();
|
||||
auto &blocks = funcOp.getBody().getBlocks();
|
||||
if (!llvm::hasSingleElement(blocks)) {
|
||||
return;
|
||||
}
|
||||
SmallVector<LinalgOp, 2> linalgOps =
|
||||
llvm::to_vector<2>(blocks.front().getOps<LinalgOp>());
|
||||
Aliases aliases;
|
||||
LinalgDependenceGraph dependenceGraph(aliases, linalgOps);
|
||||
OpBuilder builder(funcOp.getContext());
|
||||
Optional<TiledAndFusedLinalgOps> tileAndFuseOps = tileAndFuseLinalgOps(
|
||||
builder, linalgOps, dependenceGraph,
|
||||
LinalgTilingOptions().setTileSizes(tileSizes).setLoopType(
|
||||
LinalgTilingLoopType::ParallelLoops));
|
||||
if (!tileAndFuseOps)
|
||||
return signalPassFailure();
|
||||
for (auto op : linalgOps)
|
||||
op.erase();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
namespace mlir {
|
||||
|
@ -211,5 +250,12 @@ void registerTestLinalgGreedyFusion() {
|
|||
"test-linalg-greedy-fusion",
|
||||
"Test Linalg fusion by applying a greedy test transformation.");
|
||||
}
|
||||
void registerTestLinalgTileAndFuseSequencePass() {
|
||||
PassRegistration<TestLinalgTileAndFuseSequencePass>
|
||||
testTileAndFuseSequencePass(
|
||||
"test-linalg-tile-and-fuse",
|
||||
"Test Linalg tiling and fusion of a sequence of Linalg operations.");
|
||||
}
|
||||
|
||||
} // namespace test
|
||||
} // namespace mlir
|
||||
|
|
|
@ -74,6 +74,7 @@ void registerTestLinalgCodegenStrategy();
|
|||
void registerTestLinalgFusionTransforms();
|
||||
void registerTestLinalgGreedyFusion();
|
||||
void registerTestLinalgHoisting();
|
||||
void registerTestLinalgTileAndFuseSequencePass();
|
||||
void registerTestLinalgTransforms();
|
||||
void registerTestLivenessPass();
|
||||
void registerTestLoopFusion();
|
||||
|
@ -141,6 +142,7 @@ void registerTestPasses() {
|
|||
test::registerTestLinalgFusionTransforms();
|
||||
test::registerTestLinalgGreedyFusion();
|
||||
test::registerTestLinalgHoisting();
|
||||
test::registerTestLinalgTileAndFuseSequencePass();
|
||||
test::registerTestLinalgTransforms();
|
||||
test::registerTestLivenessPass();
|
||||
test::registerTestLoopFusion();
|
||||
|
|
Loading…
Reference in New Issue