[mlir][Linalg] Improve the logic to perform tile and fuse with better dependence tracking.

This change does two main things:
1) An operation might have multiple dependences to the same
   producer. Not tracking them correctly can result in incorrect code
   generation with fusion. To rectify this, the dependence tracking
   also needs to record the operand number in the consumer (see the
   first sketch below).
2) Improve the logic used to find the fused loops, making it easier to
   follow. The only constraint for fusion is that linalg ops (on
   buffers) have update semantics for the result. Fusion must be such
   that each iteration of the fused loop (which is also a tiled loop)
   touches only one (disjoint) tile of the output (see the second
   sketch below). This could be relaxed by allowing recomputation,
   which is the default when operands are tensors, or can be made
   legal with promotion of the fused view (in the future).
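As a sketch of (1) (schematic IR modeled on the matmul_plus_matmul test added in this change; types and the generic's region are elided):

  linalg.matmul ins(%a, %b) outs(%c)
  linalg.generic ... ins(%c, %c) outs(%d) ...
  // Two RAW dependences to the same producer that differ only in the
  // consumer operand number:
  //   (linalg.matmul, operand 2) -> (linalg.generic, operand 0)
  //   (linalg.matmul, operand 2) -> (linalg.generic, operand 1)
  // Tracking only the view (%c) cannot distinguish these two edges.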
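As a sketch of (2) (schematic IR in the style of the documentation comments below; loop bounds and subview sizes are elided): for a fill + matmul pair, every iteration of the fused parallel loops (i, j) reads and writes exactly one disjoint tile of %c, so both loops can be tiled and fused, while the reduction loop k cannot:

  // Before tile + fuse:
  linalg.fill(%c, %cst)
  linalg.matmul ins(%a, %b) outs(%c)

  // After tiling the matmul along i and j and fusing the fill:
  scf.parallel (%i, %j) = ... {
    %sa = subview %a[%i, 0][...]
    %sb = subview %b[0, %j][...]
    %sc = subview %c[%i, %j][...]
    linalg.fill(%sc, %cst)
    linalg.matmul ins(%sa, %sb) outs(%sc)
  }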

Differential Revision: https://reviews.llvm.org/D90579
MaheshRavishankar 2020-11-12 00:24:36 -08:00
parent ad376657c1
commit 5ca20851e4
9 changed files with 518 additions and 204 deletions


@@ -45,7 +45,7 @@ class LinalgDependenceGraph {
public:
struct LinalgOpView {
Operation *op;
Value view;
unsigned operandIndex;
};
struct LinalgDependenceGraphElem {
// dependentOpView may be either:
@@ -55,7 +55,7 @@ public:
// View in the op that is used to index in the graph:
// 1. src in the case of dependencesFromDstGraphs.
// 2. dst in the case of dependencesIntoGraphs.
Value indexingView;
LinalgOpView indexingOpView;
};
using LinalgDependences = SmallVector<LinalgDependenceGraphElem, 8>;
using DependenceGraph = DenseMap<Operation *, LinalgDependences>;


@@ -555,7 +555,7 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
>,
InterfaceMethod<
/*desc=*/[{
Return the position of the shaped operand in the operand list.
Return the first position of the shaped operand in the operand list.
}],
/*retTy=*/"Optional<unsigned>",
/*methodName=*/"getIndexOfShapedOperand",
@@ -573,6 +573,67 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
return llvm::None;
}]
>,
InterfaceMethod<
/*desc=*/[{
Returns the operand index given the input index. Returns None
if the input index is invalid.
}],
/*retTy=*/"Optional<unsigned>",
/*methodName=*/"getOperandIndexForInputIndex",
/*args=*/(ins "unsigned":$input_index),
/*methodBody=*/"",
/*defaultImplementation=*/[{
if (input_index >= $_op.getNumInputs())
return llvm::None;
return input_index;
}]
>,
InterfaceMethod<
/*desc=*/[{
Returns the operand index given the output index. Returns None
if the output index is invalid.
}],
/*retTy=*/"Optional<unsigned>",
/*methodName=*/"getOperandIndexForOutputIndex",
/*args=*/(ins "unsigned":$output_index),
/*methodBody=*/"",
/*defaultImplementation=*/[{
if (output_index >= $_op.getNumOutputs())
return llvm::None;
return output_index + $_op.getNumInputs();
}]
>,
InterfaceMethod<
/*desc=*/[{
Returns the input index given the operand index. Returns None
if the operand index does not correspond to an input.
}],
/*retTy=*/"Optional<unsigned>",
/*methodName=*/"getInputIndex",
/*args=*/(ins "unsigned":$operand_index),
/*methodBody=*/"",
/*defaultImplementation=*/[{
if (operand_index >= $_op.getNumInputs())
return llvm::None;
return operand_index;
}]
>,
InterfaceMethod<
/*desc=*/[{
Returns the output index given the operand index. Returns None
if the operand index does not correspond to an output.
}],
/*retTy=*/"Optional<unsigned>",
/*methodName=*/"getOutputIndex",
/*args=*/(ins "unsigned":$operand_index),
/*methodBody=*/"",
/*defaultImplementation=*/[{
if (operand_index < $_op.getNumInputs() ||
operand_index >= $_op.getNumInputs() + $_op.getNumOutputs())
return llvm::None;
return operand_index - $_op.getNumInputs();
}]
>,
//===------------------------------------------------------------------===//
// Other interface methods.
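As a worked example of the index arithmetic added above, consider a hypothetical structured op with two inputs and one output, so that getNumInputs() == 2 and getNumOutputs() == 1:

  linalg.generic ... ins(%A, %B) outs(%C)
  // getOperandIndexForInputIndex(1)  -> 1
  // getOperandIndexForOutputIndex(0) -> 2     (0 + getNumInputs())
  // getInputIndex(2)                 -> None  (operand 2 is an output)
  // getOutputIndex(2)                -> 0     (2 - getNumInputs())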


@@ -15,6 +15,7 @@
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/Bufferize.h"
#include "llvm/ADT/SmallBitVector.h"
#include "llvm/ADT/SmallSet.h"
namespace mlir {
class BufferizeTypeConverter;
@@ -429,12 +430,10 @@ struct LinalgTilingPattern : public LinalgBaseTilingPattern {
};
struct LinalgFusionOptions {
/// Optional list of operands indices to use for fusion. When unspecified,
/// only one fusion is done, i.e., the pattern returns after the first fusion.
Optional<DenseSet<unsigned>> indicesToFuse = None;
/// List of operand indices to use for fusion.
llvm::SmallSet<unsigned, 1> indicesToFuse = {};
LinalgFusionOptions &setIndicesToFuse(ArrayRef<int64_t> operands) {
indicesToFuse = DenseSet<unsigned>();
indicesToFuse->insert(operands.begin(), operands.end());
indicesToFuse.insert(operands.begin(), operands.end());
return *this;
}
};


@@ -323,6 +323,9 @@ AffineMap inversePermutation(AffineMap map);
/// ```
AffineMap concatAffineMaps(ArrayRef<AffineMap> maps);
AffineMap getProjectedMap(AffineMap map,
ArrayRef<unsigned> projectedDimensions);
inline raw_ostream &operator<<(raw_ostream &os, AffineMap map) {
map.print(os);
return os;


@@ -108,12 +108,14 @@ LinalgDependenceGraph::LinalgDependenceGraph(Aliases &aliases,
void LinalgDependenceGraph::addDependenceElem(DependenceType dt,
LinalgOpView indexingOpView,
LinalgOpView dependentOpView) {
LLVM_DEBUG(dbgs() << "\nAdd dep type " << getDependenceTypeStr(dt) << ":\t"
<< *indexingOpView.op << " -> " << *dependentOpView.op);
LLVM_DEBUG(dbgs() << "\nAdd dep type " << getDependenceTypeStr(dt) << ":\t ("
<< *indexingOpView.op << ", " << indexingOpView.operandIndex
<< ") -> \n\t\t(" << *dependentOpView.op << ", "
<< dependentOpView.operandIndex << ")");
dependencesFromGraphs[dt][indexingOpView.op].push_back(
LinalgDependenceGraphElem{dependentOpView, indexingOpView.view});
LinalgDependenceGraphElem{dependentOpView, indexingOpView});
dependencesIntoGraphs[dt][dependentOpView.op].push_back(
LinalgDependenceGraphElem{indexingOpView, dependentOpView.view});
LinalgDependenceGraphElem{indexingOpView, dependentOpView});
}
LinalgDependenceGraph::dependence_range
@@ -147,39 +149,55 @@ LinalgDependenceGraph::getDependencesInto(
}
void LinalgDependenceGraph::addDependencesBetween(LinalgOp src, LinalgOp dst) {
for (auto srcView : src.getOutputBuffers()) { // W
for (auto srcView : llvm::enumerate(src.getOutputBuffers())) { // W
unsigned srcIndex =
src.getOperandIndexForOutputIndex(srcView.index()).getValue();
// RAW graph
for (auto dstView : dst.getInputBuffers()) { // R
if (aliases.alias(srcView, dstView)) { // if alias, fill RAW
for (auto dstView : llvm::enumerate(dst.getInputBuffers())) { // R
if (aliases.alias(srcView.value(),
dstView.value())) { // if alias, fill RAW
unsigned dstIndex =
dst.getOperandIndexForInputIndex(dstView.index()).getValue();
addDependenceElem(DependenceType::RAW,
LinalgOpView{src.getOperation(), srcView},
LinalgOpView{dst.getOperation(), dstView});
LinalgOpView{src.getOperation(), srcIndex},
LinalgOpView{dst.getOperation(), dstIndex});
}
}
// WAW graph
for (auto dstView : dst.getOutputBuffers()) { // W
if (aliases.alias(srcView, dstView)) { // if alias, fill WAW
for (auto dstView : llvm::enumerate(dst.getOutputBuffers())) { // W
if (aliases.alias(srcView.value(),
dstView.value())) { // if alias, fill WAW
unsigned dstIndex =
dst.getOperandIndexForOutputIndex(dstView.index()).getValue();
addDependenceElem(DependenceType::WAW,
LinalgOpView{src.getOperation(), srcView},
LinalgOpView{dst.getOperation(), dstView});
LinalgOpView{src.getOperation(), srcIndex},
LinalgOpView{dst.getOperation(), dstIndex});
}
}
}
for (auto srcView : src.getInputBuffers()) { // R
for (auto srcView : llvm::enumerate(src.getInputBuffers())) { // R
unsigned srcIndex =
src.getOperandIndexForInputIndex(srcView.index()).getValue();
// RAR graph
for (auto dstView : dst.getInputBuffers()) { // R
if (aliases.alias(srcView, dstView)) { // if alias, fill RAR
for (auto dstView : llvm::enumerate(dst.getInputBuffers())) { // R
if (aliases.alias(srcView.value(),
dstView.value())) { // if alias, fill RAR
unsigned dstIndex =
dst.getOperandIndexForInputIndex(dstView.index()).getValue();
addDependenceElem(DependenceType::RAR,
LinalgOpView{src.getOperation(), srcView},
LinalgOpView{dst.getOperation(), dstView});
LinalgOpView{src.getOperation(), srcIndex},
LinalgOpView{dst.getOperation(), dstIndex});
}
}
// WAR graph
for (auto dstView : dst.getOutputBuffers()) { // W
if (aliases.alias(srcView, dstView)) { // if alias, fill WAR
for (auto dstView : llvm::enumerate(dst.getOutputBuffers())) { // W
if (aliases.alias(srcView.value(),
dstView.value())) { // if alias, fill WAR
unsigned dstIndex =
dst.getOperandIndexForOutputIndex(dstView.index()).getValue();
addDependenceElem(DependenceType::WAR,
LinalgOpView{src.getOperation(), srcView},
LinalgOpView{dst.getOperation(), dstView});
LinalgOpView{src.getOperation(), srcIndex},
LinalgOpView{dst.getOperation(), dstIndex});
}
}
}
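For instance (a sketch; assuming the fill op's buffer is its only shaped operand, i.e. operand 0), a fill + matmul pair now records the aliasing buffer together with the operand number on both ends:

  linalg.fill(%c, %cst)                 // %c is operand 0 of the fill
  linalg.matmul ins(%a, %b) outs(%c)    // %c is operand 2 of the matmul
  // WAW dependence: (linalg.fill, operand 0) -> (linalg.matmul, operand 2)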
@@ -227,12 +245,16 @@ LinalgDependenceGraph::findOperationsWithCoveringDependences(
// Skip if not interleaved.
if (interimPos >= dstPos || interimPos <= srcPos)
continue;
if (view && !aliases.alias(view, dependence.indexingView))
linalg::LinalgOp consumer =
cast<linalg::LinalgOp>(dependence.indexingOpView.op);
Value consumerView =
consumer.getShapedOperand(dependence.indexingOpView.operandIndex);
if (view && !aliases.alias(view, consumerView))
continue;
auto *op = dependence.dependentOpView.op;
LLVM_DEBUG(dbgs() << "\n***Found covering dependence of type "
<< getDependenceTypeStr(dt) << ": " << *src << " -> "
<< *op << " on " << dependence.indexingView);
<< *op << " on " << consumerView);
res.push_back(op);
}
}


@@ -24,10 +24,12 @@
#include "mlir/IR/Dominance.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/MapVector.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include <set>
#define DEBUG_TYPE "linalg-fusion"
using namespace mlir;
@@ -95,8 +97,8 @@ static LinalgOp cloneWithLoopRanges(OpBuilder &b, Location loc, LinalgOp op,
for (auto en : llvm::enumerate(op.getShapedOperands())) {
unsigned shapedOperandIdx = en.index();
AffineMap map = op.getIndexingMap(shapedOperandIdx);
LLVM_DEBUG(dbgs() << "shapedOperandIdx: " << shapedOperandIdx
<< " with indexingMap: " << map << "\n");
LLVM_DEBUG(llvm::dbgs() << "shapedOperandIdx: " << shapedOperandIdx
<< " with indexingMap: " << map << "\n");
SmallVector<Value, 4> offsets, sizes, strides;
inferShapeComponents(map, loopRanges, offsets, sizes, strides);
Value shape = en.value();
@@ -169,16 +171,18 @@ static ShapeDimension getShapeDefiningLoopRange(LinalgOp op,
for (auto en : llvm::enumerate(ios)) {
unsigned idx = en.index();
auto map = maps[idx].cast<AffineMapAttr>().getValue();
LLVM_DEBUG(dbgs() << "getShapeDefiningLoopRange I/O idx: " << idx << "\n");
LLVM_DEBUG(dbgs() << "getShapeDefiningLoopRange map: " << map << "\n");
LLVM_DEBUG(llvm::dbgs()
<< "getShapeDefiningLoopRange I/O idx: " << idx << "\n");
LLVM_DEBUG(llvm::dbgs()
<< "getShapeDefiningLoopRange map: " << map << "\n");
Value shape = en.value();
SmallVector<Value, 8> shapeRanges(map.getNumResults(), nullptr);
for (auto en2 : llvm::enumerate(map.getResults())) {
if (loopDepth == en2.value().cast<AffineDimExpr>().getPosition()) {
LLVM_DEBUG(dbgs() << "getShapeDefiningLoopRange loopDepth: "
<< loopDepth << "\n");
LLVM_DEBUG(dbgs() << "getShapeDefiningLoopRange shape: " << shape
<< "\n");
LLVM_DEBUG(llvm::dbgs() << "getShapeDefiningLoopRange loopDepth: "
<< loopDepth << "\n");
LLVM_DEBUG(llvm::dbgs()
<< "getShapeDefiningLoopRange shape: " << shape << "\n");
return ShapeDimension{shape, static_cast<unsigned>(en2.index())};
}
}
@@ -209,8 +213,8 @@ static LinalgOp fuse(OpBuilder &b, LinalgOp producer, unsigned producerIdx,
// dimension.
// TODO: extend this with range inference.
AffineMap producerMap = producer.getOutputIndexingMap(producerIdx);
LLVM_DEBUG(dbgs() << "Producer Idx: " << producerIdx
<< ", producer map: " << producerMap << "\n");
LLVM_DEBUG(llvm::dbgs() << "Producer Idx: " << producerIdx
<< ", producer map: " << producerMap << "\n");
unsigned nPar = producer.getNumParallelLoops();
unsigned nRed = producer.getNumReductionLoops();
@@ -258,7 +262,7 @@ static bool isStructurallyFusableProducer(LinalgOp producer, Value consumedView,
assert(consumer.hasBufferSemantics() &&
"expected linalg op with buffer semantics");
if (producer.getNumOutputs() != 1) {
LLVM_DEBUG(dbgs() << "\nNot structurally fusable (multi-output)");
LLVM_DEBUG(llvm::dbgs() << "\nNot structurally fusable (multi-output)");
return false;
}
// Only fuse when the producer block dominates.
@@ -266,7 +270,7 @@ static bool isStructurallyFusableProducer(LinalgOp producer, Value consumedView,
if (!dom.dominates(producer.getOperation()->getBlock(),
consumer.getOperation()->getBlock())) {
LLVM_DEBUG(
dbgs()
llvm::dbgs()
<< "\nNot structurally fusable (producer block does not dominate)");
return false;
}
@@ -284,14 +288,14 @@ bool mlir::linalg::isProducerLastWriteOfView(const LinalgDependenceGraph &graph,
// Make some simple structural checks that alleviate the need for more
// complex analyses.
if (!isStructurallyFusableProducer(producer, consumedView, consumer)) {
LLVM_DEBUG(dbgs() << "\n***Not static last write due to structure:\t"
<< *producer.getOperation());
LLVM_DEBUG(llvm::dbgs() << "\n***Not static last write due to structure:\t"
<< *producer.getOperation());
return false;
}
// Check for any interleaved write to consumedView.
if (!graph.findCoveringWrites(producer, consumer, consumedView).empty()) {
LLVM_DEBUG(dbgs() << "\n***Not fusable due to interleaved write:\t"
<< *producer.getOperation());
LLVM_DEBUG(llvm::dbgs() << "\n***Not fusable due to interleaved write:\t"
<< *producer.getOperation());
return false;
}
return true;
@@ -309,8 +313,9 @@ bool mlir::linalg::isFusableInto(const LinalgDependenceGraph &graph,
// Check for any fusion-preventing dependence to any shape read/written that
// would violate dependences.
if (!graph.findCoveringDependences(producer, consumer).empty()) {
LLVM_DEBUG(dbgs() << "\n***Not fusable due to an interleaved dependence:\t"
<< *producer.getOperation());
LLVM_DEBUG(llvm::dbgs()
<< "\n***Not fusable due to an interleaved dependence:\t"
<< *producer.getOperation());
return false;
}
if (auto convOp = dyn_cast<linalg::ConvOp>(producer.getOperation())) {
@@ -360,26 +365,33 @@ findFusableProducer(LinalgOp consumer, unsigned consumerIdx,
LinalgDependenceGraph::DependenceType::RAW,
LinalgDependenceGraph::DependenceType::WAW,
}) {
for (auto dependence :
dependenceGraph.getDependencesInto(consumer, depType)) {
for (auto dependence : llvm::make_filter_range(
dependenceGraph.getDependencesInto(consumer, depType),
[consumerIdx](
LinalgDependenceGraph::LinalgDependenceGraphElem elem) {
return elem.indexingOpView.operandIndex == consumerIdx;
})) {
auto producer = cast<LinalgOp>(dependence.dependentOpView.op);
// Check that the dependence is indeed on the input `consumerIdx` view.
auto consumedView = dependence.indexingView;
auto consumedView =
consumer.getBuffer(dependence.indexingOpView.operandIndex);
if (!isSameSubView(consumer.getBuffer(consumerIdx), consumedView))
continue;
// Consumer consumes this view, `isStructurallyFusableProducer` also
// checks whether it is a strict subview of the producer view.
auto producedView = dependence.dependentOpView.view;
auto producerIdx =
producer.getIndexOfOutputBuffer(producedView).getValue();
// `consumerIdx` and `producerIdx` exist by construction.
LLVM_DEBUG(dbgs() << "\n"
<< LinalgDependenceGraph::getDependenceTypeStr(depType)
<< "producer: " << *producer.getOperation() << " view: "
<< producedView << " output index: " << producerIdx);
(void)producerIdx;
auto producedView =
producer.getBuffer(dependence.dependentOpView.operandIndex);
LLVM_DEBUG(llvm::dbgs()
<< "\n"
<< LinalgDependenceGraph::getDependenceTypeStr(depType)
<< "producer: " << *producer.getOperation()
<< " view: " << producedView << " output index: "
<< dependence.dependentOpView.operandIndex -
producer.getNumInputs()
<< "\n");
(void)producedView;
// Simple fusability checks.
if (!isFusableInto(dependenceGraph, consumer, consumedView, producer))
@@ -406,15 +418,16 @@ mlir::linalg::fuseProducerOfBuffer(OpBuilder &b, LinalgOp consumer,
producerOp.getOperation()->getBlock())
return {};
Value producerView = fusableDependence->dependentOpView.view;
Value consumerView = fusableDependence->indexingView;
unsigned producerIdx = fusableDependence->dependentOpView.operandIndex -
producerOp.getNumInputs();
Value consumerView = consumer.getShapedOperand(consumerIdx);
// Must be a subview or a slice to guarantee there are loops we can fuse
// into.
auto subView = consumerView.getDefiningOp<SubViewOp>();
auto slice = consumerView.getDefiningOp<SliceOp>();
if (!subView && !slice) {
LLVM_DEBUG(dbgs() << "\nNot fusable (not a subview or slice)");
LLVM_DEBUG(llvm::dbgs() << "\nNot fusable (not a subview or slice)");
return {};
}
@@ -422,11 +435,7 @@ mlir::linalg::fuseProducerOfBuffer(OpBuilder &b, LinalgOp consumer,
OpBuilder::InsertionGuard g(b);
b.setInsertionPoint(consumer.getOperation());
ScopedContext scope(b, consumer.getLoc());
LLVM_DEBUG(dbgs() << "Fuse into consumer: " << *consumer << "\n");
Optional<unsigned> producerIdxOpt =
producerOp.getIndexOfOutputBuffer(producerView);
assert(producerIdxOpt.hasValue() && "incorrect operand index");
unsigned producerIdx = producerIdxOpt.getValue();
LLVM_DEBUG(llvm::dbgs() << "Fuse into consumer: " << *consumer << "\n");
auto fusedProducer = fuse(b, producerOp, producerIdx, consumer, consumerIdx);
return FusionInfo{producerOp, fusedProducer};
@@ -470,7 +479,7 @@ Optional<FusionInfo> mlir::linalg::fuseProducerOfTensor(OpBuilder &b,
// Must be a subtensor to guarantee there are loops we can fuse into.
auto subTensor = inputTensor.getDefiningOp<SubTensorOp>();
if (!subTensor || !producerOp) {
LLVM_DEBUG(dbgs() << "\nNot fusable (not a subtensor)");
LLVM_DEBUG(llvm::dbgs() << "\nNot fusable (not a subtensor)");
return {};
}
@@ -483,7 +492,7 @@ Optional<FusionInfo> mlir::linalg::fuseProducerOfTensor(OpBuilder &b,
OpBuilder::InsertionGuard g(b);
b.setInsertionPoint(consumer.getOperation());
ScopedContext scope(b, consumer.getLoc());
LLVM_DEBUG(dbgs() << "Fuse into consumer: " << *consumer << "\n");
LLVM_DEBUG(llvm::dbgs() << "Fuse into consumer: " << *consumer << "\n");
LinalgOp fusedProducer =
fuse(b, producerOp, producerIdx, consumer, consumerIdx);
@@ -501,6 +510,21 @@ Optional<FusionInfo> mlir::linalg::fuseProducerOfTensor(OpBuilder &b,
return FusionInfo{producerOp, fusedProducer};
}
/// Prune all dimensions that are of reduction iterator type from `map`.
static AffineMap pruneReductionDimsFromMap(ArrayRef<Attribute> iteratorTypes,
AffineMap map) {
SmallVector<unsigned, 2> projectedDims;
for (auto attr : llvm::enumerate(iteratorTypes)) {
if (!isParallelIterator(attr.value()))
projectedDims.push_back(attr.index());
}
return getProjectedMap(map, projectedDims);
}
using FusableOpDependencesTy = llvm::MapVector<
Operation *,
SmallVector<LinalgDependenceGraph::LinalgDependenceGraphElem, 1>>;
/// Returns the positions of the loop in `op` that can be tiled based on the
/// operations that are to be fused with it. For example, in a
///
@@ -508,12 +532,58 @@ Optional<FusionInfo> mlir::linalg::fuseProducerOfTensor(OpBuilder &b,
///
/// if the producer of %a needs to be fused with this op, only the `i` loop of
/// the matmul can be tiled while fusing. If producer of %a, and %b are to be
/// fused, then no loops can be tiled while fusing.
static DenseSet<unsigned> collectTileAndFuseLoops(
LinalgOp op, ArrayRef<LinalgDependenceGraph::LinalgDependenceGraphElem>
fusableDependences) {
// 1. Only parallel loops can be used for tile + fuse. Find the number of
// common outer parallel loops between the op and its producers being fused.
/// fused, then no loops can be tiled while fusing. The conditions used are:
/// 1. Only parallel loops can be used for tile + fuse. Find the number of
/// common outer parallel loops between the op and its producers being fused.
/// 2. Of the parallel loops only some can be fused. Only those loops can be
/// fused for which one iteration of the fused loop touches only one tile
/// of the fused view. This is because the producer (which is writing
/// the fused subview) has update semantics. To compute this,
/// a. Find the mapping from iterations in the consumer that write to the
/// same location as the iterations in the producer. To do so use
/// - indexing map of the fused view in the consumer : consumerIndexMap
/// - indexing map of the fused view in the producer : producerIndexMap
/// consumerLoopToProducerLoop =
/// inverse(producerIndexMap).compose(consumerIndexMap)
///
/// Since an inverse computation is needed, we need to consider the projection
/// of the producerIndexMap w.r.t. the parallel loops. The actual fusable loops
/// are the dimensions of the consumerLoopToProducerLoop map that correspond to
/// parallel loops and appear in the results of the map.
///
/// Example 1:
/// linalg.fill(%c, %cst)
/// linalg.matmul ins(%a, %b) outs(%c)
/// Number of parallel loops : 2
/// producerIndexMap = affine_map<(i, j) ->(i , j)>
/// consumerIndexMap = affine_map<(i, j, k) -> (i, j)>
/// consumerLoopToProducerLoop = affine_map<(i, j, k) -> (i, j)>
/// Fused dimensions : i, j
///
/// Example 2:
/// linalg.matmul ins(%a, %b) outs(%c)
/// linalg.generic {indexing_maps = [affine_map<(i, j) -> (j, i)>, ...
/// iterator_types = ["parallel", "parallel"]}
/// ins(%c) ...
///
/// Number of parallel loops = 2:
/// producerIndexMap (projected to parallel loops) =
/// affine_map<(i, j) -> (i, j)>
/// consumerLoopToProducerLoop2 = affine_map<(i, j) -> (j, i)>
/// Fused dimensions : i, j
///
/// Example 3:
/// linalg.copy(%s, %b)
/// linalg.matmul ins(%a, %b) outs(%c)
///
/// Number of parallel loops = 2
/// producerIndexMap : affine_map<(i, j) -> (i, j)>
/// consumerLoopToProducerLoop = affine_map<(i, j, k) -> (k, j)>
/// submap with only parallel loops = affine_map<(i, j) -> (j)>
/// Fused dimensions : j
static std::set<unsigned>
collectTileAndFuseLoops(LinalgOp op,
const FusableOpDependencesTy &fusableDependences) {
auto getNumOuterParallelLoops = [](LinalgOp linalgOp) {
return linalgOp.iterator_types()
.getValue()
@@ -524,135 +594,149 @@ static DenseSet<unsigned> collectTileAndFuseLoops(
.size();
};
LLVM_DEBUG({
llvm::dbgs() << "Op : ";
op.getOperation()->print(llvm::dbgs(), OpPrintingFlags().useLocalScope());
llvm::dbgs() << "\n";
});
size_t numOuterParallelLoops = getNumOuterParallelLoops(op);
for (auto dependence : fusableDependences) {
linalg::LinalgOp producer = cast<linalg::LinalgOp>(dependence.first);
numOuterParallelLoops =
std::min(numOuterParallelLoops, getNumOuterParallelLoops(cast<LinalgOp>(
dependence.dependentOpView.op)));
std::min(numOuterParallelLoops, getNumOuterParallelLoops(producer));
}
// Need to compute what tiled loops can be "fused". Given the precondition
// that all indexing map for the producer view is a projected permutation, we
// can assert that the producer iterates over the dimensions of the "fused
// view" only once. To be used a fused loop the producer should use this loop
// to access the fused view. For example, consider
//
// ```
// linalg.add ins(%a, %b) outs(%c)
// linalg.matmul ins(%d, %c) outs(%e)
// ```
//
// if `linalg.add` has the semantics of `c = a + b`, then the following
// tile+fuse code is correct.
//
// ```
// for j ... += TSj
// %sa = subview %a[0, %j][...]
// %sb = subview %b[0, %j][...]
// %sc = subview %c[0, %j][...]
// %sd = subview %d[0, 0][...]
// %se = subview %e[0, %j][...]
// linalg.add ins(%sa, %sb) outs(%sc)
// linalg.matmul ins(%sd, %sc) outs(%se)
// ```
//
// On the other hand tiling along i would be incorrect
//
// ```
// for %i .. += TSi
// %sa = subview %a[%i, 0][...]
// %sb = subview %b[%i, 0][...]
// %sc = subview %c[%i, 0][...]
// %sc2 = subview %c[0, 0][...]
// %sd = subview %d[%i, 0][...]
// %se = subview %e[%i, 0][...]
// linalg.add ins(%sa, %sb) outs(%sc)
// linalg.matmul ins(%sd, %sc2) outs(%se)
// ```
//
// The write to the subview `%sc` in `linalg.add` is performed after the read
// from it using `%sc2` violating the RAW dependence of the original code. To
// find such loops indexing map of the fused view in the consumer op is
// used. For the above example, this indexing map is
//
// affine_map<(d0, d1, d2) -> (d2, d1)>
//
// Since d0 is not in the result expressions of this map, it is not treated as
// tile + fuse loop, (but d1 is).
//
// TODO: The above is probably restrictive and there might be a generalization
// of these that might allow for more fusion opportunities. Explore based on
// needs.
SmallVector<DenseSet<unsigned>, 1> commonTilableLoops;
std::set<unsigned> fusableLoops;
auto range = llvm::seq<unsigned>(0, numOuterParallelLoops);
fusableLoops.insert(range.begin(), range.end());
for (auto dependence : fusableDependences) {
unsigned consumerIdx =
op.getIndexOfShapedOperand(dependence.indexingView).getValue();
AffineMap consumerAccess = op.getIndexingMap(consumerIdx);
// Previously asserted that the consumerAccess map is a projected
// permutation, so all results are known to be AffineDimExprs. To remove
// this restriction walk the expression to find which dimensions of the
// consumer loop appear in the `consumerAccess`.
DenseSet<unsigned> positions;
for (auto expr : consumerAccess.getResults())
positions.insert(expr.cast<AffineDimExpr>().getPosition());
commonTilableLoops.emplace_back(std::move(positions));
LLVM_DEBUG({
llvm::dbgs() << "\t fusable :";
for (unsigned i : fusableLoops)
llvm::dbgs() << " " << i;
llvm::dbgs() << "\n";
});
linalg::LinalgOp producer = cast<linalg::LinalgOp>(dependence.first);
assert(!dependence.second.empty() &&
"unexpected producer but not dependences");
AffineMap producerIndexingMap = producer.getIndexingMap(
dependence.second.front().dependentOpView.operandIndex);
AffineMap prunedProducerIndexingMap = pruneReductionDimsFromMap(
producer.iterator_types().getValue(), producerIndexingMap);
if (!prunedProducerIndexingMap.isPermutation())
return {};
AffineMap consumerIndexingMap = op.getIndexingMap(
dependence.second.front().indexingOpView.operandIndex);
if (consumerIndexingMap.getNumResults() !=
prunedProducerIndexingMap.getNumResults())
return {};
LLVM_DEBUG({
llvm::dbgs() << "\t producerMap : ";
producerIndexingMap.print(llvm::dbgs());
llvm::dbgs() << " pruned : ";
prunedProducerIndexingMap.print(llvm::dbgs());
llvm::dbgs() << "\n";
llvm::dbgs() << "\t consumerMap : ";
consumerIndexingMap.print(llvm::dbgs());
llvm::dbgs() << "\n";
});
AffineMap invProducerIndexMap =
inversePermutation(prunedProducerIndexingMap);
if (!invProducerIndexMap)
return {};
AffineMap consumerLoopToProducerLoop =
invProducerIndexMap.compose(consumerIndexingMap);
LLVM_DEBUG({
llvm::dbgs() << "\t consumerLoopToProducerLoop : ";
consumerLoopToProducerLoop.print(llvm::dbgs());
});
std::set<unsigned> candidates;
for (AffineExpr expr : consumerLoopToProducerLoop.getResults()) {
AffineDimExpr dimExpr = expr.dyn_cast<AffineDimExpr>();
if (!dimExpr)
continue;
unsigned position = dimExpr.getPosition();
if (fusableLoops.count(position))
candidates.insert(position);
}
LLVM_DEBUG({
llvm::dbgs() << "\t candidates :";
for (unsigned i : candidates)
llvm::dbgs() << " " << i;
llvm::dbgs() << "\n";
});
if (candidates.empty())
return {};
std::swap(candidates, fusableLoops);
}
// 2. Of the outer parallel loops, only those loops can be tiled + fused as
// computed above for all the fused dependences can be used to tile and fuse.
DenseSet<unsigned> tilableParallelLoops;
for (auto index : llvm::seq<unsigned>(0, numOuterParallelLoops)) {
if (llvm::all_of(commonTilableLoops,
[&](const DenseSet<unsigned> &tilableLoops) {
return tilableLoops.count(index);
}))
tilableParallelLoops.insert(index);
}
return tilableParallelLoops;
return fusableLoops;
}
/// Find all dependences that are to be fusable.
static Optional<
SmallVector<LinalgDependenceGraph::LinalgDependenceGraphElem, 1>>
static FusableOpDependencesTy
findAllFusableDependences(LinalgOp op,
const LinalgDependenceGraph &dependenceGraph,
const LinalgFusionOptions &fusionOptions) {
SmallVector<LinalgDependenceGraph::LinalgDependenceGraphElem, 1>
fusableDependences;
for (auto operand : llvm::enumerate(op.getInputsAndOutputBuffers())) {
if (fusionOptions.indicesToFuse &&
!fusionOptions.indicesToFuse->count(operand.index()))
continue;
Optional<LinalgDependenceGraph::LinalgDependenceGraphElem>
fusableDependence =
findFusableProducer(op, operand.index(), dependenceGraph);
FusableOpDependencesTy fusableDependences;
// TODO: Currently fusion would not be legal if the fusable dependence is to
// the same producer but with a different indexing map in the consumer. Fix
// this, but in the meantime disallow such a fusion.
DenseMap<Operation *, AffineMap> fusedProducerIndexingMap;
for (auto operandIndex : fusionOptions.indicesToFuse) {
auto fusableDependence =
findFusableProducer(op, operandIndex, dependenceGraph);
if (!fusableDependence)
continue;
return FusableOpDependencesTy{};
LinalgOp producerOp = cast<LinalgOp>(fusableDependence->dependentOpView.op);
// Do not fuse dependences that are to operations not in the same basic
// block. This avoids moving fused operations across loops that might
// themselves carry dependences, making the fusion illegal.
if (producerOp.getOperation()->getBlock() !=
op.getOperation()->getBlock()) {
op.emitRemark("unhandled fusion of ops in different basic blocks");
return FusableOpDependencesTy{};
}
// Make sure that the indexing map of the view used for fusion in the
// producer is a projected permutation.
LinalgOp producerOp = cast<LinalgOp>(fusableDependence->dependentOpView.op);
Value producerView = fusableDependence->dependentOpView.view;
unsigned producerIdx =
producerOp.getIndexOfOutputBuffer(producerView).getValue();
AffineMap producerMap = producerOp.getOutputIndexingMap(producerIdx);
unsigned producerIdx = fusableDependence->dependentOpView.operandIndex;
AffineMap producerMap = producerOp.getIndexingMap(producerIdx);
if (!producerMap.isProjectedPermutation()) {
op.emitError("unhandled non permutation indexing map for fused view in "
"producer for operand at index ")
<< operand.index();
return llvm::None;
op.emitRemark("unhandled non permutation indexing map for fused view in "
"producer for operand at index ")
<< operandIndex;
return FusableOpDependencesTy{};
}
Value consumerView = fusableDependence->indexingView;
unsigned consumerIdx = op.getIndexOfShapedOperand(consumerView).getValue();
if (!op.getIndexingMap(consumerIdx).isProjectedPermutation()) {
op.emitError(
unsigned consumerIdx = fusableDependence->indexingOpView.operandIndex;
AffineMap consumerMap = op.getIndexingMap(consumerIdx);
if (!consumerMap.isProjectedPermutation()) {
op.emitRemark(
"unhandled case where indexing map for fused view in the consumer is "
"not a projected permuration while fusing at index ")
<< operand.index();
return llvm::None;
"not a projected permutation while fusing at index ")
<< operandIndex;
return FusableOpDependencesTy{};
}
fusableDependences.push_back(*fusableDependence);
if (!fusionOptions.indicesToFuse)
break;
// Check if the producer is already a fusion candidate. Cannot fuse this
// dependence if it has a different indexing map when used in the consumer.
if (fusedProducerIndexingMap.count(producerOp.getOperation()) &&
fusedProducerIndexingMap[producerOp.getOperation()] != consumerMap) {
op.emitRemark("unhandled fusion to the same producer but with different "
"indexing maps");
return FusableOpDependencesTy{};
}
fusedProducerIndexingMap[producerOp.getOperation()] = consumerMap;
fusableDependences[producerOp.getOperation()].push_back(*fusableDependence);
}
return fusableDependences;
}
@@ -682,13 +766,10 @@ tileAndFuseLinalgOpsImpl(PatternRewriter &rewriter, LinalgOp op,
ScopedContext scope(rewriter, op.getLoc());
// Find all the producers.
Optional<SmallVector<LinalgDependenceGraph::LinalgDependenceGraphElem, 1>>
fusableDependencesOpt =
findAllFusableDependences(op, dependenceGraph, fusionOptions);
if (!fusableDependencesOpt)
FusableOpDependencesTy fusableDependences =
findAllFusableDependences(op, dependenceGraph, fusionOptions);
if (fusableDependences.empty())
return llvm::None;
ArrayRef<LinalgDependenceGraph::LinalgDependenceGraphElem> fusableDependences(
*fusableDependencesOpt);
// Enforce the convention that "tiling by zero" skips tiling a particular
// dimension. This convention is significantly simpler to handle instead of
@@ -704,12 +785,12 @@ tileAndFuseLinalgOpsImpl(PatternRewriter &rewriter, LinalgOp op,
TiledAndFusedLinalgOps ret;
// Find the loops that can be tiled and fused.
DenseSet<unsigned> tileFuseLoops =
std::set<unsigned> tileFuseLoops =
collectTileAndFuseLoops(op, fusableDependences);
// If there are no fusable dependences or there are no tile+fusable loops,
// just return.
if (fusableDependences.empty() || tileFuseLoops.empty()) {
if (tileFuseLoops.empty()) {
return llvm::None;
}
@@ -752,15 +833,15 @@ tileAndFuseLinalgOpsImpl(PatternRewriter &rewriter, LinalgOp op,
rewriter.setInsertionPoint(ret.op);
// Fuse the operands.
for (auto producer : enumerate(fusableDependences)) {
LinalgOp producerOp = cast<LinalgOp>(producer.value().dependentOpView.op);
for (auto dependence : fusableDependences) {
LinalgOp producerOp = cast<LinalgOp>(dependence.first);
unsigned producerIdx =
producerOp.getIndexOfOutputBuffer(producer.value().dependentOpView.view)
.getValue();
dependence.second.front().dependentOpView.operandIndex;
unsigned consumerIdx =
op.getIndexOfShapedOperand(producer.value().indexingView).getValue();
LinalgOp fusedOp =
fuse(rewriter, producerOp, producerIdx, ret.op, consumerIdx);
dependence.second.front().indexingOpView.operandIndex;
LinalgOp fusedOp = fuse(rewriter, producerOp,
producerOp.getOutputIndex(producerIdx).getValue(),
ret.op, consumerIdx);
ret.fusedProducers.push_back(fusedOp);
ret.originalProducers.push_back(producerOp);
}


@@ -12,6 +12,7 @@
#include "mlir/IR/StandardTypes.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Support/MathExtras.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/raw_ostream.h"
@@ -450,6 +451,22 @@ AffineMap mlir::concatAffineMaps(ArrayRef<AffineMap> maps) {
maps.front().getContext());
}
AffineMap mlir::getProjectedMap(AffineMap map,
ArrayRef<unsigned> projectedDimensions) {
DenseSet<unsigned> projectedDims(projectedDimensions.begin(),
projectedDimensions.end());
MLIRContext *context = map.getContext();
SmallVector<AffineExpr, 4> resultExprs;
for (auto dim : enumerate(llvm::seq<unsigned>(0, map.getNumDims()))) {
if (!projectedDims.count(dim.value()))
resultExprs.push_back(getAffineDimExpr(dim.index(), context));
else
resultExprs.push_back(getAffineConstantExpr(0, context));
}
return map.compose(AffineMap::get(
map.getNumDims() - projectedDimensions.size(), 0, resultExprs, context));
}
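A small worked illustration of getProjectedMap (matmul-style maps assumed): each projected dimension maps to the constant 0 in the composed map and is dropped from the dimension list. This is the behavior pruneReductionDimsFromMap in the fusion pass relies on when it projects out reduction dimensions:

  // getProjectedMap(affine_map<(d0, d1, d2) -> (d0, d1)>, {2})
  //   == affine_map<(d0, d1) -> (d0, d1)>
  // getProjectedMap(affine_map<(d0, d1, d2) -> (d2, d1)>, {2})
  //   == affine_map<(d0, d1) -> (0, d1)>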
//===----------------------------------------------------------------------===//
// MutableAffineMap.
//===----------------------------------------------------------------------===//


@@ -1,4 +1,4 @@
// RUN: mlir-opt %s -test-linalg-fusion-transform-patterns -canonicalize -cse -split-input-file | FileCheck %s
// RUN: mlir-opt %s -test-linalg-fusion-transform-patterns -canonicalize -cse -split-input-file -verify-diagnostics | FileCheck %s
module {
func @basic_fusion(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>,
@@ -295,3 +295,121 @@ module {
// CHECK: }
// CHECK: linalg.matmul
// CHECK-SAME: __internal_linalg_transform__ = "after_lhs_fusion_original"
// -----
module {
func @matmul_plus_matmul(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>,
%arg2: memref<?x?xf32>) {
%c0 = constant 0 : index
%c1 = constant 1 : index
%0 = dim %arg2, %c0 : memref<?x?xf32>
%1 = dim %arg2, %c1 : memref<?x?xf32>
%2 = alloc(%0, %1) : memref<?x?xf32>
linalg.matmul ins(%arg0, %arg1 : memref<?x?xf32>, memref<?x?xf32>)
outs(%2 : memref<?x?xf32>)
linalg.generic
{indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
affine_map<(d0, d1) -> (d0, d1)>,
affine_map<(d0, d1) -> (d0, d1)>],
iterator_types = ["parallel", "parallel"],
__internal_linalg_transform__ = "transpose_fusion"}
ins(%2, %2 : memref<?x?xf32>, memref<?x?xf32>)
outs(%arg2 : memref<?x?xf32>) {
^bb0(%arg3 : f32, %arg4 : f32, %arg5 : f32) :
%3 = addf %arg3, %arg4 : f32
linalg.yield %3 : f32
}
return
}
}
// CHECK: func @matmul_plus_matmul
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: memref<?x?xf32>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: memref<?x?xf32>
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: memref<?x?xf32>
// CHECK: %[[T2:.+]] = alloc(%{{.*}}, %{{.*}}) : memref<?x?xf32>
// CHECK: linalg.matmul
// CHECK-SAME: after_transpose_fusion_original
// CHECK: scf.parallel (%[[ARG3:[a-zA-Z0-9_]+]], %[[ARG4:.[a-zA-Z0-9_]+]])
// CHECK: %[[T5:.+]] = subview %[[T2]][%[[ARG3]], %[[ARG4]]]
// CHECK: %[[T6:.+]] = subview %[[ARG2]][%[[ARG3]], %[[ARG4]]]
// CHECK: %[[T8:.+]] = subview %[[ARG0]][%[[ARG3]], 0]
// CHECK: %[[T9:.+]] = subview %[[ARG1]][0, %[[ARG4]]]
// CHECK: linalg.matmul
// CHECK-SAME: after_transpose_fusion_producer
// CHECK-SAME: ins(%[[T8]], %[[T9]]
// CHECK-SAME: outs(%[[T5]]
// CHECK-NOT: linalg.matmul
// CHECK: linalg.generic
// CHECK-SAME: ins(%[[T5]], %[[T5]]
// CHECK-SAME: outs(%[[T6]]
// CHECK-SAME: after_transpose_fusion
// -----
module {
func @matmul_plus_transpose_matmul(%arg0: memref<?x?xf32>,
%arg1: memref<?x?xf32>,
%arg2: memref<?x?xf32>) {
%c0 = constant 0 : index
%c1 = constant 1 : index
%0 = dim %arg2, %c0 : memref<?x?xf32>
%1 = dim %arg2, %c1 : memref<?x?xf32>
%2 = alloc(%0, %1) : memref<?x?xf32>
linalg.matmul ins(%arg0, %arg1 : memref<?x?xf32>, memref<?x?xf32>)
outs(%2 : memref<?x?xf32>)
// expected-remark @+1 {{unhandled fusion to the same producer but with different indexing maps}}
linalg.generic
{indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
affine_map<(d0, d1) -> (d1, d0)>,
affine_map<(d0, d1) -> (d0, d1)>],
iterator_types = ["parallel", "parallel"],
__internal_linalg_transform__ = "transpose_fusion"}
ins(%2, %2 : memref<?x?xf32>, memref<?x?xf32>)
outs(%arg2 : memref<?x?xf32>) {
^bb0(%arg3 : f32, %arg4 : f32, %arg5 : f32) :
%3 = addf %arg3, %arg4 : f32
linalg.yield %3 : f32
}
return
}
}
// -----
#map0 = affine_map<(d0)[s0] -> (32, -d0 + s0)>
#map1 = affine_map<(d0)[s0] -> (64, -d0 + s0)>
#map2 = affine_map<(d0)[s0] -> (16, -d0 + s0)>
#map3 = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>
module {
func @basic_no_fusion(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>,
%arg2: memref<?x?xf32>) {
%c0 = constant 0 : index
%c1 = constant 1 : index
%c2 = constant 2 : index
%c32 = constant 32 : index
%c64 = constant 64 : index
%c16 = constant 16 : index
%cst = constant 0.000000e+00 : f32
linalg.fill(%arg2, %cst) : memref<?x?xf32>, f32
%0 = dim %arg0, %c0 : memref<?x?xf32>
%1 = dim %arg1, %c1 : memref<?x?xf32>
%2 = dim %arg0, %c1 : memref<?x?xf32>
scf.parallel (%arg3, %arg4) = (%c0, %c0) to (%0, %1) step (%c32, %c64) {
scf.for %arg5 = %c0 to %2 step %c16 {
%3 = affine.min #map0(%arg3)[%0]
%4 = affine.min #map1(%arg4)[%1]
%5 = affine.min #map2(%arg5)[%2]
%6 = subview %arg0[%arg3, %arg5] [%3, %5] [1, 1] : memref<?x?xf32> to memref<?x?xf32, #map3>
%7 = subview %arg1[%arg5, %arg4] [%5, %4] [1, 1] : memref<?x?xf32> to memref<?x?xf32, #map3>
%8 = subview %arg2[%arg3, %arg4] [%3, %4] [1, 1] : memref<?x?xf32> to memref<?x?xf32, #map3>
// expected-remark @+1 {{unhandled fusion of ops in different basic blocks}}
linalg.matmul {__internal_linalg_transform__ = "basic_fusion"}
ins(%6, %7 : memref<?x?xf32, #map3>, memref<?x?xf32, #map3>)
outs(%8 : memref<?x?xf32, #map3>)
}
scf.yield
}
return
}
}


@@ -43,7 +43,7 @@ static void fillFusionPatterns(MLIRContext *context,
LinalgTilingOptions()
.setTileSizes({32, 64, 16})
.setLoopType(LinalgTilingLoopType::ParallelLoops),
LinalgFusionOptions(),
LinalgFusionOptions().setIndicesToFuse({2}),
LinalgMarker(Identifier::get("basic_fusion", context),
Identifier::get("after_basic_fusion", context)),
LinalgMarker(ArrayRef<Identifier>(),
@@ -91,6 +91,19 @@ static void fillFusionPatterns(MLIRContext *context,
LinalgMarker(
ArrayRef<Identifier>(),
Identifier::get("after_two_operand_fusion_original", context)));
patterns.insert<LinalgTileAndFusePattern<GenericOp>>(
context, dependenceGraph,
LinalgTilingOptions().setTileSizes({32, 64}).setLoopType(
LinalgTilingLoopType::ParallelLoops),
LinalgFusionOptions().setIndicesToFuse({0, 1}),
LinalgMarker(Identifier::get("transpose_fusion", context),
Identifier::get("after_transpose_fusion", context)),
LinalgMarker(ArrayRef<Identifier>(),
Identifier::get("after_transpose_fusion_producer", context)),
LinalgMarker(
ArrayRef<Identifier>(),
Identifier::get("after_transpose_fusion_original", context)));
}
static void applyFusionPatterns(MLIRContext *context, FuncOp funcOp) {