forked from OSchip/llvm-project
[mlir][Linalg] Improve the logic to perform tile and fuse with better dependence tracking.
This change does two main things 1) An operation might have multiple dependences to the same producer. Not tracking them correctly can result in incorrect code generation with fusion. To rectify this the dependence tracking needs to also have the operand number in the consumer. 2) Improve the logic used to find the fused loops making it easier to follow. The only constraint for fusion is that linalg ops (on buffers) have update semantics for the result. Fusion should be such that only one iteration of the fused loop (which is also a tiled loop) must touch only one (disjoint) tile of the output. This could be relaxed by allowing for recomputation that is the default when oeprands are tensors, or can be made legal with promotion of the fused view (in future). Differential Revision: https://reviews.llvm.org/D90579
This commit is contained in:
parent
ad376657c1
commit
5ca20851e4
|
@ -45,7 +45,7 @@ class LinalgDependenceGraph {
|
||||||
public:
|
public:
|
||||||
struct LinalgOpView {
|
struct LinalgOpView {
|
||||||
Operation *op;
|
Operation *op;
|
||||||
Value view;
|
unsigned operandIndex;
|
||||||
};
|
};
|
||||||
struct LinalgDependenceGraphElem {
|
struct LinalgDependenceGraphElem {
|
||||||
// dependentOpView may be either:
|
// dependentOpView may be either:
|
||||||
|
@ -55,7 +55,7 @@ public:
|
||||||
// View in the op that is used to index in the graph:
|
// View in the op that is used to index in the graph:
|
||||||
// 1. src in the case of dependencesFromDstGraphs.
|
// 1. src in the case of dependencesFromDstGraphs.
|
||||||
// 2. dst in the case of dependencesIntoGraphs.
|
// 2. dst in the case of dependencesIntoGraphs.
|
||||||
Value indexingView;
|
LinalgOpView indexingOpView;
|
||||||
};
|
};
|
||||||
using LinalgDependences = SmallVector<LinalgDependenceGraphElem, 8>;
|
using LinalgDependences = SmallVector<LinalgDependenceGraphElem, 8>;
|
||||||
using DependenceGraph = DenseMap<Operation *, LinalgDependences>;
|
using DependenceGraph = DenseMap<Operation *, LinalgDependences>;
|
||||||
|
|
|
@ -555,7 +555,7 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
|
||||||
>,
|
>,
|
||||||
InterfaceMethod<
|
InterfaceMethod<
|
||||||
/*desc=*/[{
|
/*desc=*/[{
|
||||||
Return the position of the shaped operand in the operand list.
|
Return the first position of the shaped operand in the operand list.
|
||||||
}],
|
}],
|
||||||
/*retTy=*/"Optional<unsigned>",
|
/*retTy=*/"Optional<unsigned>",
|
||||||
/*methodName=*/"getIndexOfShapedOperand",
|
/*methodName=*/"getIndexOfShapedOperand",
|
||||||
|
@ -573,6 +573,67 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
|
||||||
return llvm::None;
|
return llvm::None;
|
||||||
}]
|
}]
|
||||||
>,
|
>,
|
||||||
|
InterfaceMethod<
|
||||||
|
/*desc=*/[{
|
||||||
|
Returns the operand index given the input index. Returns None
|
||||||
|
of the input index is invalid.
|
||||||
|
}],
|
||||||
|
/*retTy=*/"Optional<unsigned>",
|
||||||
|
/*methodName=*/"getOperandIndexForInputIndex",
|
||||||
|
/*args=*/(ins "unsigned":$input_index),
|
||||||
|
/*methodBody=*/"",
|
||||||
|
/*defaultImplementation=*/[{
|
||||||
|
if (input_index >= $_op.getNumInputs())
|
||||||
|
return llvm::None;
|
||||||
|
return input_index;
|
||||||
|
}]
|
||||||
|
>,
|
||||||
|
InterfaceMethod<
|
||||||
|
/*desc=*/[{
|
||||||
|
Returns the operand index given the output index. Returns None
|
||||||
|
of the output index is invalid.
|
||||||
|
}],
|
||||||
|
/*retTy=*/"Optional<unsigned>",
|
||||||
|
/*methodName=*/"getOperandIndexForOutputIndex",
|
||||||
|
/*args=*/(ins "unsigned":$output_index),
|
||||||
|
/*methodBody=*/"",
|
||||||
|
/*defaultImplementation=*/[{
|
||||||
|
if (output_index >= $_op.getNumOutputs())
|
||||||
|
return llvm::None;
|
||||||
|
return output_index + $_op.getNumInputs();
|
||||||
|
}]
|
||||||
|
>,
|
||||||
|
InterfaceMethod<
|
||||||
|
/*desc=*/[{
|
||||||
|
Returns the input index given the operand index. Return None
|
||||||
|
if the operand index doesnt corresponding to an input.
|
||||||
|
}],
|
||||||
|
/*retTy=*/"Optional<unsigned>",
|
||||||
|
/*methodName=*/"getInputIndex",
|
||||||
|
/*args=*/(ins "unsigned":$operand_index),
|
||||||
|
/*methodBody=*/"",
|
||||||
|
/*defaultImplementation=*/[{
|
||||||
|
if (operand_index >= $_op.getNumInputs())
|
||||||
|
return llvm::None;
|
||||||
|
return operand_index;
|
||||||
|
}]
|
||||||
|
>,
|
||||||
|
InterfaceMethod<
|
||||||
|
/*desc=*/[{
|
||||||
|
Returns the output index given the operand index. Return None
|
||||||
|
if the operand index doesnt corresponding to an output.
|
||||||
|
}],
|
||||||
|
/*retTy=*/"Optional<unsigned>",
|
||||||
|
/*methodName=*/"getOutputIndex",
|
||||||
|
/*args=*/(ins "unsigned":$operand_index),
|
||||||
|
/*methodBody=*/"",
|
||||||
|
/*defaultImplementation=*/[{
|
||||||
|
if (operand_index < $_op.getNumInputs() ||
|
||||||
|
operand_index >= $_op.getNumInputs() + $_op.getNumOutputs())
|
||||||
|
return llvm::None;
|
||||||
|
return operand_index - $_op.getNumInputs();
|
||||||
|
}]
|
||||||
|
>,
|
||||||
|
|
||||||
//===------------------------------------------------------------------===//
|
//===------------------------------------------------------------------===//
|
||||||
// Other interface methods.
|
// Other interface methods.
|
||||||
|
|
|
@ -15,6 +15,7 @@
|
||||||
#include "mlir/IR/PatternMatch.h"
|
#include "mlir/IR/PatternMatch.h"
|
||||||
#include "mlir/Transforms/Bufferize.h"
|
#include "mlir/Transforms/Bufferize.h"
|
||||||
#include "llvm/ADT/SmallBitVector.h"
|
#include "llvm/ADT/SmallBitVector.h"
|
||||||
|
#include "llvm/ADT/SmallSet.h"
|
||||||
|
|
||||||
namespace mlir {
|
namespace mlir {
|
||||||
class BufferizeTypeConverter;
|
class BufferizeTypeConverter;
|
||||||
|
@ -429,12 +430,10 @@ struct LinalgTilingPattern : public LinalgBaseTilingPattern {
|
||||||
};
|
};
|
||||||
|
|
||||||
struct LinalgFusionOptions {
|
struct LinalgFusionOptions {
|
||||||
/// Optional list of operands indices to use for fusion. When unspecified,
|
/// List of operands indices to use for fusion.
|
||||||
/// only one fusion is done, i.e., the pattern returns after the first fusion.
|
llvm::SmallSet<unsigned, 1> indicesToFuse = {};
|
||||||
Optional<DenseSet<unsigned>> indicesToFuse = None;
|
|
||||||
LinalgFusionOptions &setIndicesToFuse(ArrayRef<int64_t> operands) {
|
LinalgFusionOptions &setIndicesToFuse(ArrayRef<int64_t> operands) {
|
||||||
indicesToFuse = DenseSet<unsigned>();
|
indicesToFuse.insert(operands.begin(), operands.end());
|
||||||
indicesToFuse->insert(operands.begin(), operands.end());
|
|
||||||
return *this;
|
return *this;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
|
@ -323,6 +323,9 @@ AffineMap inversePermutation(AffineMap map);
|
||||||
/// ```
|
/// ```
|
||||||
AffineMap concatAffineMaps(ArrayRef<AffineMap> maps);
|
AffineMap concatAffineMaps(ArrayRef<AffineMap> maps);
|
||||||
|
|
||||||
|
AffineMap getProjectedMap(AffineMap map,
|
||||||
|
ArrayRef<unsigned> projectedDimensions);
|
||||||
|
|
||||||
inline raw_ostream &operator<<(raw_ostream &os, AffineMap map) {
|
inline raw_ostream &operator<<(raw_ostream &os, AffineMap map) {
|
||||||
map.print(os);
|
map.print(os);
|
||||||
return os;
|
return os;
|
||||||
|
|
|
@ -108,12 +108,14 @@ LinalgDependenceGraph::LinalgDependenceGraph(Aliases &aliases,
|
||||||
void LinalgDependenceGraph::addDependenceElem(DependenceType dt,
|
void LinalgDependenceGraph::addDependenceElem(DependenceType dt,
|
||||||
LinalgOpView indexingOpView,
|
LinalgOpView indexingOpView,
|
||||||
LinalgOpView dependentOpView) {
|
LinalgOpView dependentOpView) {
|
||||||
LLVM_DEBUG(dbgs() << "\nAdd dep type " << getDependenceTypeStr(dt) << ":\t"
|
LLVM_DEBUG(dbgs() << "\nAdd dep type " << getDependenceTypeStr(dt) << ":\t ("
|
||||||
<< *indexingOpView.op << " -> " << *dependentOpView.op);
|
<< *indexingOpView.op << ", " << indexingOpView.operandIndex
|
||||||
|
<< ") -> \n\t\t(" << *dependentOpView.op << ", "
|
||||||
|
<< dependentOpView.operandIndex << ")");
|
||||||
dependencesFromGraphs[dt][indexingOpView.op].push_back(
|
dependencesFromGraphs[dt][indexingOpView.op].push_back(
|
||||||
LinalgDependenceGraphElem{dependentOpView, indexingOpView.view});
|
LinalgDependenceGraphElem{dependentOpView, indexingOpView});
|
||||||
dependencesIntoGraphs[dt][dependentOpView.op].push_back(
|
dependencesIntoGraphs[dt][dependentOpView.op].push_back(
|
||||||
LinalgDependenceGraphElem{indexingOpView, dependentOpView.view});
|
LinalgDependenceGraphElem{indexingOpView, dependentOpView});
|
||||||
}
|
}
|
||||||
|
|
||||||
LinalgDependenceGraph::dependence_range
|
LinalgDependenceGraph::dependence_range
|
||||||
|
@ -147,39 +149,55 @@ LinalgDependenceGraph::getDependencesInto(
|
||||||
}
|
}
|
||||||
|
|
||||||
void LinalgDependenceGraph::addDependencesBetween(LinalgOp src, LinalgOp dst) {
|
void LinalgDependenceGraph::addDependencesBetween(LinalgOp src, LinalgOp dst) {
|
||||||
for (auto srcView : src.getOutputBuffers()) { // W
|
for (auto srcView : llvm::enumerate(src.getOutputBuffers())) { // W
|
||||||
|
unsigned srcIndex =
|
||||||
|
src.getOperandIndexForOutputIndex(srcView.index()).getValue();
|
||||||
// RAW graph
|
// RAW graph
|
||||||
for (auto dstView : dst.getInputBuffers()) { // R
|
for (auto dstView : llvm::enumerate(dst.getInputBuffers())) { // R
|
||||||
if (aliases.alias(srcView, dstView)) { // if alias, fill RAW
|
if (aliases.alias(srcView.value(),
|
||||||
|
dstView.value())) { // if alias, fill RAW
|
||||||
|
unsigned dstIndex =
|
||||||
|
dst.getOperandIndexForInputIndex(dstView.index()).getValue();
|
||||||
addDependenceElem(DependenceType::RAW,
|
addDependenceElem(DependenceType::RAW,
|
||||||
LinalgOpView{src.getOperation(), srcView},
|
LinalgOpView{src.getOperation(), srcIndex},
|
||||||
LinalgOpView{dst.getOperation(), dstView});
|
LinalgOpView{dst.getOperation(), dstIndex});
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// WAW graph
|
// WAW graph
|
||||||
for (auto dstView : dst.getOutputBuffers()) { // W
|
for (auto dstView : llvm::enumerate(dst.getOutputBuffers())) { // W
|
||||||
if (aliases.alias(srcView, dstView)) { // if alias, fill WAW
|
if (aliases.alias(srcView.value(),
|
||||||
|
dstView.value())) { // if alias, fill WAW
|
||||||
|
unsigned dstIndex =
|
||||||
|
dst.getOperandIndexForOutputIndex(dstView.index()).getValue();
|
||||||
addDependenceElem(DependenceType::WAW,
|
addDependenceElem(DependenceType::WAW,
|
||||||
LinalgOpView{src.getOperation(), srcView},
|
LinalgOpView{src.getOperation(), srcIndex},
|
||||||
LinalgOpView{dst.getOperation(), dstView});
|
LinalgOpView{dst.getOperation(), dstIndex});
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
for (auto srcView : src.getInputBuffers()) { // R
|
for (auto srcView : llvm::enumerate(src.getInputBuffers())) { // R
|
||||||
|
unsigned srcIndex =
|
||||||
|
src.getOperandIndexForInputIndex(srcView.index()).getValue();
|
||||||
// RAR graph
|
// RAR graph
|
||||||
for (auto dstView : dst.getInputBuffers()) { // R
|
for (auto dstView : llvm::enumerate(dst.getInputBuffers())) { // R
|
||||||
if (aliases.alias(srcView, dstView)) { // if alias, fill RAR
|
if (aliases.alias(srcView.value(),
|
||||||
|
dstView.value())) { // if alias, fill RAR
|
||||||
|
unsigned dstIndex =
|
||||||
|
dst.getOperandIndexForInputIndex(dstView.index()).getValue();
|
||||||
addDependenceElem(DependenceType::RAR,
|
addDependenceElem(DependenceType::RAR,
|
||||||
LinalgOpView{src.getOperation(), srcView},
|
LinalgOpView{src.getOperation(), srcIndex},
|
||||||
LinalgOpView{dst.getOperation(), dstView});
|
LinalgOpView{dst.getOperation(), dstIndex});
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// WAR graph
|
// WAR graph
|
||||||
for (auto dstView : dst.getOutputBuffers()) { // W
|
for (auto dstView : llvm::enumerate(dst.getOutputBuffers())) { // W
|
||||||
if (aliases.alias(srcView, dstView)) { // if alias, fill WAR
|
if (aliases.alias(srcView.value(),
|
||||||
|
dstView.value())) { // if alias, fill WAR
|
||||||
|
unsigned dstIndex =
|
||||||
|
dst.getOperandIndexForOutputIndex(dstView.index()).getValue();
|
||||||
addDependenceElem(DependenceType::WAR,
|
addDependenceElem(DependenceType::WAR,
|
||||||
LinalgOpView{src.getOperation(), srcView},
|
LinalgOpView{src.getOperation(), srcIndex},
|
||||||
LinalgOpView{dst.getOperation(), dstView});
|
LinalgOpView{dst.getOperation(), dstIndex});
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -227,12 +245,16 @@ LinalgDependenceGraph::findOperationsWithCoveringDependences(
|
||||||
// Skip if not interleaved.
|
// Skip if not interleaved.
|
||||||
if (interimPos >= dstPos || interimPos <= srcPos)
|
if (interimPos >= dstPos || interimPos <= srcPos)
|
||||||
continue;
|
continue;
|
||||||
if (view && !aliases.alias(view, dependence.indexingView))
|
linalg::LinalgOp consumer =
|
||||||
|
cast<linalg::LinalgOp>(dependence.indexingOpView.op);
|
||||||
|
Value consumerView =
|
||||||
|
consumer.getShapedOperand(dependence.indexingOpView.operandIndex);
|
||||||
|
if (view && !aliases.alias(view, consumerView))
|
||||||
continue;
|
continue;
|
||||||
auto *op = dependence.dependentOpView.op;
|
auto *op = dependence.dependentOpView.op;
|
||||||
LLVM_DEBUG(dbgs() << "\n***Found covering dependence of type "
|
LLVM_DEBUG(dbgs() << "\n***Found covering dependence of type "
|
||||||
<< getDependenceTypeStr(dt) << ": " << *src << " -> "
|
<< getDependenceTypeStr(dt) << ": " << *src << " -> "
|
||||||
<< *op << " on " << dependence.indexingView);
|
<< *op << " on " << consumerView);
|
||||||
res.push_back(op);
|
res.push_back(op);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -24,10 +24,12 @@
|
||||||
#include "mlir/IR/Dominance.h"
|
#include "mlir/IR/Dominance.h"
|
||||||
#include "mlir/Support/LLVM.h"
|
#include "mlir/Support/LLVM.h"
|
||||||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
||||||
#include "llvm/ADT/SetVector.h"
|
#include "llvm/ADT/MapVector.h"
|
||||||
#include "llvm/Support/CommandLine.h"
|
#include "llvm/Support/CommandLine.h"
|
||||||
#include "llvm/Support/Debug.h"
|
#include "llvm/Support/Debug.h"
|
||||||
|
|
||||||
|
#include <set>
|
||||||
|
|
||||||
#define DEBUG_TYPE "linalg-fusion"
|
#define DEBUG_TYPE "linalg-fusion"
|
||||||
|
|
||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
|
@ -95,7 +97,7 @@ static LinalgOp cloneWithLoopRanges(OpBuilder &b, Location loc, LinalgOp op,
|
||||||
for (auto en : llvm::enumerate(op.getShapedOperands())) {
|
for (auto en : llvm::enumerate(op.getShapedOperands())) {
|
||||||
unsigned shapedOperandIdx = en.index();
|
unsigned shapedOperandIdx = en.index();
|
||||||
AffineMap map = op.getIndexingMap(shapedOperandIdx);
|
AffineMap map = op.getIndexingMap(shapedOperandIdx);
|
||||||
LLVM_DEBUG(dbgs() << "shapedOperandIdx: " << shapedOperandIdx
|
LLVM_DEBUG(llvm::dbgs() << "shapedOperandIdx: " << shapedOperandIdx
|
||||||
<< " with indexingMap: " << map << "\n");
|
<< " with indexingMap: " << map << "\n");
|
||||||
SmallVector<Value, 4> offsets, sizes, strides;
|
SmallVector<Value, 4> offsets, sizes, strides;
|
||||||
inferShapeComponents(map, loopRanges, offsets, sizes, strides);
|
inferShapeComponents(map, loopRanges, offsets, sizes, strides);
|
||||||
|
@ -169,16 +171,18 @@ static ShapeDimension getShapeDefiningLoopRange(LinalgOp op,
|
||||||
for (auto en : llvm::enumerate(ios)) {
|
for (auto en : llvm::enumerate(ios)) {
|
||||||
unsigned idx = en.index();
|
unsigned idx = en.index();
|
||||||
auto map = maps[idx].cast<AffineMapAttr>().getValue();
|
auto map = maps[idx].cast<AffineMapAttr>().getValue();
|
||||||
LLVM_DEBUG(dbgs() << "getShapeDefiningLoopRange I/O idx: " << idx << "\n");
|
LLVM_DEBUG(llvm::dbgs()
|
||||||
LLVM_DEBUG(dbgs() << "getShapeDefiningLoopRange map: " << map << "\n");
|
<< "getShapeDefiningLoopRange I/O idx: " << idx << "\n");
|
||||||
|
LLVM_DEBUG(llvm::dbgs()
|
||||||
|
<< "getShapeDefiningLoopRange map: " << map << "\n");
|
||||||
Value shape = en.value();
|
Value shape = en.value();
|
||||||
SmallVector<Value, 8> shapeRanges(map.getNumResults(), nullptr);
|
SmallVector<Value, 8> shapeRanges(map.getNumResults(), nullptr);
|
||||||
for (auto en2 : llvm::enumerate(map.getResults())) {
|
for (auto en2 : llvm::enumerate(map.getResults())) {
|
||||||
if (loopDepth == en2.value().cast<AffineDimExpr>().getPosition()) {
|
if (loopDepth == en2.value().cast<AffineDimExpr>().getPosition()) {
|
||||||
LLVM_DEBUG(dbgs() << "getShapeDefiningLoopRange loopDepth: "
|
LLVM_DEBUG(llvm::dbgs() << "getShapeDefiningLoopRange loopDepth: "
|
||||||
<< loopDepth << "\n");
|
<< loopDepth << "\n");
|
||||||
LLVM_DEBUG(dbgs() << "getShapeDefiningLoopRange shape: " << shape
|
LLVM_DEBUG(llvm::dbgs()
|
||||||
<< "\n");
|
<< "getShapeDefiningLoopRange shape: " << shape << "\n");
|
||||||
return ShapeDimension{shape, static_cast<unsigned>(en2.index())};
|
return ShapeDimension{shape, static_cast<unsigned>(en2.index())};
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -209,7 +213,7 @@ static LinalgOp fuse(OpBuilder &b, LinalgOp producer, unsigned producerIdx,
|
||||||
// dimension.
|
// dimension.
|
||||||
// TODO: extend this with range inference.
|
// TODO: extend this with range inference.
|
||||||
AffineMap producerMap = producer.getOutputIndexingMap(producerIdx);
|
AffineMap producerMap = producer.getOutputIndexingMap(producerIdx);
|
||||||
LLVM_DEBUG(dbgs() << "Producer Idx: " << producerIdx
|
LLVM_DEBUG(llvm::dbgs() << "Producer Idx: " << producerIdx
|
||||||
<< ", producer map: " << producerMap << "\n");
|
<< ", producer map: " << producerMap << "\n");
|
||||||
|
|
||||||
unsigned nPar = producer.getNumParallelLoops();
|
unsigned nPar = producer.getNumParallelLoops();
|
||||||
|
@ -258,7 +262,7 @@ static bool isStructurallyFusableProducer(LinalgOp producer, Value consumedView,
|
||||||
assert(consumer.hasBufferSemantics() &&
|
assert(consumer.hasBufferSemantics() &&
|
||||||
"expected linalg op with buffer semantics");
|
"expected linalg op with buffer semantics");
|
||||||
if (producer.getNumOutputs() != 1) {
|
if (producer.getNumOutputs() != 1) {
|
||||||
LLVM_DEBUG(dbgs() << "\nNot structurally fusable (multi-output)");
|
LLVM_DEBUG(llvm::dbgs() << "\nNot structurally fusable (multi-output)");
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
// Only fuse when the producer block dominates.
|
// Only fuse when the producer block dominates.
|
||||||
|
@ -266,7 +270,7 @@ static bool isStructurallyFusableProducer(LinalgOp producer, Value consumedView,
|
||||||
if (!dom.dominates(producer.getOperation()->getBlock(),
|
if (!dom.dominates(producer.getOperation()->getBlock(),
|
||||||
consumer.getOperation()->getBlock())) {
|
consumer.getOperation()->getBlock())) {
|
||||||
LLVM_DEBUG(
|
LLVM_DEBUG(
|
||||||
dbgs()
|
llvm::dbgs()
|
||||||
<< "\nNot structurally fusable (producer block does not dominate)");
|
<< "\nNot structurally fusable (producer block does not dominate)");
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
@ -284,13 +288,13 @@ bool mlir::linalg::isProducerLastWriteOfView(const LinalgDependenceGraph &graph,
|
||||||
// Make some simple structural checks that alleviate the need for more
|
// Make some simple structural checks that alleviate the need for more
|
||||||
// complex analyses.
|
// complex analyses.
|
||||||
if (!isStructurallyFusableProducer(producer, consumedView, consumer)) {
|
if (!isStructurallyFusableProducer(producer, consumedView, consumer)) {
|
||||||
LLVM_DEBUG(dbgs() << "\n***Not static last write due to structure:\t"
|
LLVM_DEBUG(llvm::dbgs() << "\n***Not static last write due to structure:\t"
|
||||||
<< *producer.getOperation());
|
<< *producer.getOperation());
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
// Check for any interleaved write to consumedView.
|
// Check for any interleaved write to consumedView.
|
||||||
if (!graph.findCoveringWrites(producer, consumer, consumedView).empty()) {
|
if (!graph.findCoveringWrites(producer, consumer, consumedView).empty()) {
|
||||||
LLVM_DEBUG(dbgs() << "\n***Not fusable due to interleaved write:\t"
|
LLVM_DEBUG(llvm::dbgs() << "\n***Not fusable due to interleaved write:\t"
|
||||||
<< *producer.getOperation());
|
<< *producer.getOperation());
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
@ -309,7 +313,8 @@ bool mlir::linalg::isFusableInto(const LinalgDependenceGraph &graph,
|
||||||
// Check for any fusion-preventing dependence to any shape read/written that
|
// Check for any fusion-preventing dependence to any shape read/written that
|
||||||
// would violate dependences.
|
// would violate dependences.
|
||||||
if (!graph.findCoveringDependences(producer, consumer).empty()) {
|
if (!graph.findCoveringDependences(producer, consumer).empty()) {
|
||||||
LLVM_DEBUG(dbgs() << "\n***Not fusable due to an interleaved dependence:\t"
|
LLVM_DEBUG(llvm::dbgs()
|
||||||
|
<< "\n***Not fusable due to an interleaved dependence:\t"
|
||||||
<< *producer.getOperation());
|
<< *producer.getOperation());
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
@ -360,26 +365,33 @@ findFusableProducer(LinalgOp consumer, unsigned consumerIdx,
|
||||||
LinalgDependenceGraph::DependenceType::RAW,
|
LinalgDependenceGraph::DependenceType::RAW,
|
||||||
LinalgDependenceGraph::DependenceType::WAW,
|
LinalgDependenceGraph::DependenceType::WAW,
|
||||||
}) {
|
}) {
|
||||||
for (auto dependence :
|
for (auto dependence : llvm::make_filter_range(
|
||||||
dependenceGraph.getDependencesInto(consumer, depType)) {
|
dependenceGraph.getDependencesInto(consumer, depType),
|
||||||
|
[consumerIdx](
|
||||||
|
LinalgDependenceGraph::LinalgDependenceGraphElem elem) {
|
||||||
|
return elem.indexingOpView.operandIndex == consumerIdx;
|
||||||
|
})) {
|
||||||
auto producer = cast<LinalgOp>(dependence.dependentOpView.op);
|
auto producer = cast<LinalgOp>(dependence.dependentOpView.op);
|
||||||
|
|
||||||
// Check that the dependence is indeed on the input `consumerIdx` view.
|
// Check that the dependence is indeed on the input `consumerIdx` view.
|
||||||
auto consumedView = dependence.indexingView;
|
auto consumedView =
|
||||||
|
consumer.getBuffer(dependence.indexingOpView.operandIndex);
|
||||||
if (!isSameSubView(consumer.getBuffer(consumerIdx), consumedView))
|
if (!isSameSubView(consumer.getBuffer(consumerIdx), consumedView))
|
||||||
continue;
|
continue;
|
||||||
|
|
||||||
// Consumer consumes this view, `isStructurallyFusableProducer` also
|
// Consumer consumes this view, `isStructurallyFusableProducer` also
|
||||||
// checks whether it is a strict subview of the producer view.
|
// checks whether it is a strict subview of the producer view.
|
||||||
auto producedView = dependence.dependentOpView.view;
|
auto producedView =
|
||||||
auto producerIdx =
|
producer.getBuffer(dependence.dependentOpView.operandIndex);
|
||||||
producer.getIndexOfOutputBuffer(producedView).getValue();
|
LLVM_DEBUG(llvm::dbgs()
|
||||||
// `consumerIdx` and `producerIdx` exist by construction.
|
<< "\n"
|
||||||
LLVM_DEBUG(dbgs() << "\n"
|
|
||||||
<< LinalgDependenceGraph::getDependenceTypeStr(depType)
|
<< LinalgDependenceGraph::getDependenceTypeStr(depType)
|
||||||
<< "producer: " << *producer.getOperation() << " view: "
|
<< "producer: " << *producer.getOperation()
|
||||||
<< producedView << " output index: " << producerIdx);
|
<< " view: " << producedView << " output index: "
|
||||||
(void)producerIdx;
|
<< dependence.dependentOpView.operandIndex -
|
||||||
|
producer.getNumInputs()
|
||||||
|
<< "\n");
|
||||||
|
(void)producedView;
|
||||||
|
|
||||||
// Simple fusability checks.
|
// Simple fusability checks.
|
||||||
if (!isFusableInto(dependenceGraph, consumer, consumedView, producer))
|
if (!isFusableInto(dependenceGraph, consumer, consumedView, producer))
|
||||||
|
@ -406,15 +418,16 @@ mlir::linalg::fuseProducerOfBuffer(OpBuilder &b, LinalgOp consumer,
|
||||||
producerOp.getOperation()->getBlock())
|
producerOp.getOperation()->getBlock())
|
||||||
return {};
|
return {};
|
||||||
|
|
||||||
Value producerView = fusableDependence->dependentOpView.view;
|
unsigned producerIdx = fusableDependence->dependentOpView.operandIndex -
|
||||||
Value consumerView = fusableDependence->indexingView;
|
producerOp.getNumInputs();
|
||||||
|
Value consumerView = consumer.getShapedOperand(consumerIdx);
|
||||||
|
|
||||||
// Must be a subview or a slice to guarantee there are loops we can fuse
|
// Must be a subview or a slice to guarantee there are loops we can fuse
|
||||||
// into.
|
// into.
|
||||||
auto subView = consumerView.getDefiningOp<SubViewOp>();
|
auto subView = consumerView.getDefiningOp<SubViewOp>();
|
||||||
auto slice = consumerView.getDefiningOp<SliceOp>();
|
auto slice = consumerView.getDefiningOp<SliceOp>();
|
||||||
if (!subView && !slice) {
|
if (!subView && !slice) {
|
||||||
LLVM_DEBUG(dbgs() << "\nNot fusable (not a subview or slice)");
|
LLVM_DEBUG(llvm::dbgs() << "\nNot fusable (not a subview or slice)");
|
||||||
return {};
|
return {};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -422,11 +435,7 @@ mlir::linalg::fuseProducerOfBuffer(OpBuilder &b, LinalgOp consumer,
|
||||||
OpBuilder::InsertionGuard g(b);
|
OpBuilder::InsertionGuard g(b);
|
||||||
b.setInsertionPoint(consumer.getOperation());
|
b.setInsertionPoint(consumer.getOperation());
|
||||||
ScopedContext scope(b, consumer.getLoc());
|
ScopedContext scope(b, consumer.getLoc());
|
||||||
LLVM_DEBUG(dbgs() << "Fuse into consumer: " << *consumer << "\n");
|
LLVM_DEBUG(llvm::dbgs() << "Fuse into consumer: " << *consumer << "\n");
|
||||||
Optional<unsigned> producerIdxOpt =
|
|
||||||
producerOp.getIndexOfOutputBuffer(producerView);
|
|
||||||
assert(producerIdxOpt.hasValue() && "incorrect operand index");
|
|
||||||
unsigned producerIdx = producerIdxOpt.getValue();
|
|
||||||
|
|
||||||
auto fusedProducer = fuse(b, producerOp, producerIdx, consumer, consumerIdx);
|
auto fusedProducer = fuse(b, producerOp, producerIdx, consumer, consumerIdx);
|
||||||
return FusionInfo{producerOp, fusedProducer};
|
return FusionInfo{producerOp, fusedProducer};
|
||||||
|
@ -470,7 +479,7 @@ Optional<FusionInfo> mlir::linalg::fuseProducerOfTensor(OpBuilder &b,
|
||||||
// Must be a subtensor to guarantee there are loops we can fuse into.
|
// Must be a subtensor to guarantee there are loops we can fuse into.
|
||||||
auto subTensor = inputTensor.getDefiningOp<SubTensorOp>();
|
auto subTensor = inputTensor.getDefiningOp<SubTensorOp>();
|
||||||
if (!subTensor || !producerOp) {
|
if (!subTensor || !producerOp) {
|
||||||
LLVM_DEBUG(dbgs() << "\nNot fusable (not a subtensor)");
|
LLVM_DEBUG(llvm::dbgs() << "\nNot fusable (not a subtensor)");
|
||||||
return {};
|
return {};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -483,7 +492,7 @@ Optional<FusionInfo> mlir::linalg::fuseProducerOfTensor(OpBuilder &b,
|
||||||
OpBuilder::InsertionGuard g(b);
|
OpBuilder::InsertionGuard g(b);
|
||||||
b.setInsertionPoint(consumer.getOperation());
|
b.setInsertionPoint(consumer.getOperation());
|
||||||
ScopedContext scope(b, consumer.getLoc());
|
ScopedContext scope(b, consumer.getLoc());
|
||||||
LLVM_DEBUG(dbgs() << "Fuse into consumer: " << *consumer << "\n");
|
LLVM_DEBUG(llvm::dbgs() << "Fuse into consumer: " << *consumer << "\n");
|
||||||
LinalgOp fusedProducer =
|
LinalgOp fusedProducer =
|
||||||
fuse(b, producerOp, producerIdx, consumer, consumerIdx);
|
fuse(b, producerOp, producerIdx, consumer, consumerIdx);
|
||||||
|
|
||||||
|
@ -501,6 +510,21 @@ Optional<FusionInfo> mlir::linalg::fuseProducerOfTensor(OpBuilder &b,
|
||||||
return FusionInfo{producerOp, fusedProducer};
|
return FusionInfo{producerOp, fusedProducer};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Prune all dimensions that are of reduction iterator type from `map`.
|
||||||
|
static AffineMap pruneReductionDimsFromMap(ArrayRef<Attribute> iteratorTypes,
|
||||||
|
AffineMap map) {
|
||||||
|
SmallVector<unsigned, 2> projectedDims;
|
||||||
|
for (auto attr : llvm::enumerate(iteratorTypes)) {
|
||||||
|
if (!isParallelIterator(attr.value()))
|
||||||
|
projectedDims.push_back(attr.index());
|
||||||
|
}
|
||||||
|
return getProjectedMap(map, projectedDims);
|
||||||
|
}
|
||||||
|
|
||||||
|
using FusableOpDependencesTy = llvm::MapVector<
|
||||||
|
Operation *,
|
||||||
|
SmallVector<LinalgDependenceGraph::LinalgDependenceGraphElem, 1>>;
|
||||||
|
|
||||||
/// Returns the positions of the loop in `op` that can be tiled based on the
|
/// Returns the positions of the loop in `op` that can be tiled based on the
|
||||||
/// operations that are to be fused with it. For example, in a
|
/// operations that are to be fused with it. For example, in a
|
||||||
///
|
///
|
||||||
|
@ -508,12 +532,58 @@ Optional<FusionInfo> mlir::linalg::fuseProducerOfTensor(OpBuilder &b,
|
||||||
///
|
///
|
||||||
/// if the producer of %a needs to be fused with this op, only the `i` loop of
|
/// if the producer of %a needs to be fused with this op, only the `i` loop of
|
||||||
/// the matmul can be tiled while fusing. If producer of %a, and %b are to be
|
/// the matmul can be tiled while fusing. If producer of %a, and %b are to be
|
||||||
/// fused, then no loops can be tiled while fusing.
|
/// fused, then no loops can be tiled while fusing. The conditions used are:
|
||||||
static DenseSet<unsigned> collectTileAndFuseLoops(
|
/// 1. Only parallel loops can be used for tile + fuse. Find the number of
|
||||||
LinalgOp op, ArrayRef<LinalgDependenceGraph::LinalgDependenceGraphElem>
|
/// common outer parallel loops between the op and its producers being fused.
|
||||||
fusableDependences) {
|
/// 2. Of the parallel loops only some can be fused. Only those loops can be
|
||||||
// 1. Only parallel loops can be used for tile + fuse. Find the number of
|
/// fused such where the fusable loops iteration space only touches one tile
|
||||||
// common outer parallel loops between the op and its producers being fused.
|
/// of the fused operation. This is because the producer (which is writing
|
||||||
|
/// the fused subview) has update semantics. To compute this,
|
||||||
|
/// a. Find the mapping from iterations in the consumer that write to the
|
||||||
|
/// same location as the iterations in the producer. To do so use
|
||||||
|
/// - indexing map of the fused view in the consumer : consumerIndexMap
|
||||||
|
/// - indexing map of the fused view in the producer : producerIndexMap
|
||||||
|
/// consumerLoopToProducerLoop =
|
||||||
|
/// inverse(producerIndexMap).compose(consumerIndexMap)
|
||||||
|
///
|
||||||
|
/// Since an inverse computation is needed, we need to consider the projection
|
||||||
|
/// of the producerIndexMap w.r.t the parallel loops. The actual fusable loops
|
||||||
|
/// are the dimensions of the consumerLoopToProducerLoop map that correspond to
|
||||||
|
/// parallel loops and appear in the result of the map
|
||||||
|
///
|
||||||
|
/// Example 1:
|
||||||
|
/// linalg.fill(%c, %cst)
|
||||||
|
/// linalg.matmul ins(%a, %b) outs(%c)
|
||||||
|
/// Number of parallel loops : 2
|
||||||
|
/// producerIndexMap = affine_map<(i, j) ->(i , j)>
|
||||||
|
/// consumerIndexMap = affine_map<(i, j, k) -> (i, j)>
|
||||||
|
/// consumerLoopToProducerLoop = affine_map<(i, j, k) -> (i, j)>
|
||||||
|
/// Fused dimensions : i, j
|
||||||
|
///
|
||||||
|
/// Example 2:
|
||||||
|
/// linalg.matmul ins(%a, %b) outs(%c)
|
||||||
|
/// linalg.generic {indexing_maps = [affine_map<(i, j) -> (j, i)>, ...
|
||||||
|
/// iterator_types = ["parallel", "parallel"]}
|
||||||
|
/// ins(%c) ...
|
||||||
|
///
|
||||||
|
/// Number of parallel loops = 2:
|
||||||
|
/// producerIndexMap (projected to parallel loops) =
|
||||||
|
/// affine_map<(i, j) -> (i, j)>
|
||||||
|
/// consumerLoopToProducerLoop2 = affine_map<(i, j) -> (j, i)>
|
||||||
|
/// Fused dimensions : i, j
|
||||||
|
///
|
||||||
|
/// Example 3:
|
||||||
|
/// linalg.copy(%s, %b)
|
||||||
|
/// linalg.matmul ins(%a, %b) outs(%c)
|
||||||
|
///
|
||||||
|
/// Number of parallel loops = 2
|
||||||
|
/// produceIndexMap : affine_map<(i, j) -> (i, j)>
|
||||||
|
/// consumerLoopToProduceLoops = affine_map<(i, j, k) -> (k, j)>
|
||||||
|
/// submap with only parallel loops = affine_map<(i, j) -> (j)>
|
||||||
|
/// Fused dimensions : j
|
||||||
|
static std::set<unsigned>
|
||||||
|
collectTileAndFuseLoops(LinalgOp op,
|
||||||
|
const FusableOpDependencesTy &fusableDependences) {
|
||||||
auto getNumOuterParallelLoops = [](LinalgOp linalgOp) {
|
auto getNumOuterParallelLoops = [](LinalgOp linalgOp) {
|
||||||
return linalgOp.iterator_types()
|
return linalgOp.iterator_types()
|
||||||
.getValue()
|
.getValue()
|
||||||
|
@ -524,135 +594,149 @@ static DenseSet<unsigned> collectTileAndFuseLoops(
|
||||||
.size();
|
.size();
|
||||||
};
|
};
|
||||||
|
|
||||||
|
LLVM_DEBUG({
|
||||||
|
llvm::dbgs() << "Op : ";
|
||||||
|
op.getOperation()->print(llvm::dbgs(), OpPrintingFlags().useLocalScope());
|
||||||
|
llvm::dbgs() << "\n";
|
||||||
|
});
|
||||||
|
|
||||||
size_t numOuterParallelLoops = getNumOuterParallelLoops(op);
|
size_t numOuterParallelLoops = getNumOuterParallelLoops(op);
|
||||||
for (auto dependence : fusableDependences) {
|
for (auto dependence : fusableDependences) {
|
||||||
|
linalg::LinalgOp producer = cast<linalg::LinalgOp>(dependence.first);
|
||||||
numOuterParallelLoops =
|
numOuterParallelLoops =
|
||||||
std::min(numOuterParallelLoops, getNumOuterParallelLoops(cast<LinalgOp>(
|
std::min(numOuterParallelLoops, getNumOuterParallelLoops(producer));
|
||||||
dependence.dependentOpView.op)));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Need to compute what tiled loops can be "fused". Given the precondition
|
std::set<unsigned> fusableLoops;
|
||||||
// that all indexing map for the producer view is a projected permutation, we
|
auto range = llvm::seq<unsigned>(0, numOuterParallelLoops);
|
||||||
// can assert that the producer iterates over the dimensions of the "fused
|
fusableLoops.insert(range.begin(), range.end());
|
||||||
// view" only once. To be used a fused loop the producer should use this loop
|
|
||||||
// to access the fused view. For example, consider
|
|
||||||
//
|
|
||||||
// ```
|
|
||||||
// linalg.add ins(%a, %b) outs(%c)
|
|
||||||
// linalg.matmul ins(%d, %c) outs(%e)
|
|
||||||
// ```
|
|
||||||
//
|
|
||||||
// if `linalg.add` has the semantics of `c = a + b`, then the following
|
|
||||||
// tile+fuse code is correct.
|
|
||||||
//
|
|
||||||
// ```
|
|
||||||
// for j ... += TSj
|
|
||||||
// %sa = subview %a[0, %j][...]
|
|
||||||
// %sb = subview %b[0, %j][...]
|
|
||||||
// %sc = subview %c[0, %j][...]
|
|
||||||
// %sd = subview %d[0, 0][...]
|
|
||||||
// %se = subview %e[0, %j][...]
|
|
||||||
// linalg.add ins(%sa, %sb) outs(%sc)
|
|
||||||
// linalg.matmul ins(%sd, %sc) outs(%se)
|
|
||||||
// ```
|
|
||||||
//
|
|
||||||
// On the other hand tiling along i would be incorrect
|
|
||||||
//
|
|
||||||
// ```
|
|
||||||
// for %i .. += TSi
|
|
||||||
// %sa = subview %a[%i, 0][...]
|
|
||||||
// %sb = subview %b[%i, 0][...]
|
|
||||||
// %sc = subview %c[%i, 0][...]
|
|
||||||
// %sc2 = subview %c[0, 0][...]
|
|
||||||
// %sd = subview %d[%i, 0][...]
|
|
||||||
// %se = subview %e[%i, 0][...]
|
|
||||||
// linalg.add ins(%sa, %sb) outs(%sc)
|
|
||||||
// linalg.matmul ins(%sd, %sc2) outs(%se)
|
|
||||||
// ```
|
|
||||||
//
|
|
||||||
// The write to the subview `%sc` in `linalg.add` is performed after the read
|
|
||||||
// from it using `%sc2` violating the RAW dependence of the original code. To
|
|
||||||
// find such loops indexing map of the fused view in the consumer op is
|
|
||||||
// used. For the above example, this indexing map is
|
|
||||||
//
|
|
||||||
// affine_map<(d0, d1, d2) -> (d2, d1)>
|
|
||||||
//
|
|
||||||
// Since d0 is not in the result expressions of this map, it is not treated as
|
|
||||||
// tile + fuse loop, (but d1 is).
|
|
||||||
//
|
|
||||||
// TODO: The above is probably restrictive and there might be a generalization
|
|
||||||
// of these that might allow for more fusion opportunities. Explore based on
|
|
||||||
// needs.
|
|
||||||
SmallVector<DenseSet<unsigned>, 1> commonTilableLoops;
|
|
||||||
for (auto dependence : fusableDependences) {
|
for (auto dependence : fusableDependences) {
|
||||||
unsigned consumerIdx =
|
LLVM_DEBUG({
|
||||||
op.getIndexOfShapedOperand(dependence.indexingView).getValue();
|
llvm::dbgs() << "\t fusable :";
|
||||||
AffineMap consumerAccess = op.getIndexingMap(consumerIdx);
|
for (unsigned i : fusableLoops)
|
||||||
// Previously asserted that the consumerAccess map is a projected
|
llvm::dbgs() << " " << i;
|
||||||
// permutation, so all results are known to be AffineDimExprs. To remove
|
llvm::dbgs() << "\n";
|
||||||
// this restriction walk the expression to find which dimensions of the
|
});
|
||||||
// consumer loop appear in the `consumerAccess`.
|
linalg::LinalgOp producer = cast<linalg::LinalgOp>(dependence.first);
|
||||||
DenseSet<unsigned> positions;
|
|
||||||
for (auto expr : consumerAccess.getResults())
|
assert(!dependence.second.empty() &&
|
||||||
positions.insert(expr.cast<AffineDimExpr>().getPosition());
|
"unexpected producer but not dependences");
|
||||||
commonTilableLoops.emplace_back(std::move(positions));
|
AffineMap producerIndexingMap = producer.getIndexingMap(
|
||||||
|
dependence.second.front().dependentOpView.operandIndex);
|
||||||
|
AffineMap prunedProducerIndexingMap = pruneReductionDimsFromMap(
|
||||||
|
producer.iterator_types().getValue(), producerIndexingMap);
|
||||||
|
if (!prunedProducerIndexingMap.isPermutation())
|
||||||
|
return {};
|
||||||
|
|
||||||
|
AffineMap consumerIndexingMap = op.getIndexingMap(
|
||||||
|
dependence.second.front().indexingOpView.operandIndex);
|
||||||
|
if (consumerIndexingMap.getNumResults() !=
|
||||||
|
prunedProducerIndexingMap.getNumResults())
|
||||||
|
return {};
|
||||||
|
|
||||||
|
LLVM_DEBUG({
|
||||||
|
llvm::dbgs() << "\t producerMap : ";
|
||||||
|
producerIndexingMap.print(llvm::dbgs());
|
||||||
|
llvm::dbgs() << " pruned : ";
|
||||||
|
prunedProducerIndexingMap.print(llvm::dbgs());
|
||||||
|
llvm::dbgs() << "\n";
|
||||||
|
llvm::dbgs() << "\t consumerMap : ";
|
||||||
|
consumerIndexingMap.print(llvm::dbgs());
|
||||||
|
llvm::dbgs() << "\n";
|
||||||
|
});
|
||||||
|
|
||||||
|
AffineMap invProducerIndexMap =
|
||||||
|
inversePermutation(prunedProducerIndexingMap);
|
||||||
|
if (!invProducerIndexMap)
|
||||||
|
return {};
|
||||||
|
|
||||||
|
AffineMap consumerLoopToProducerLoop =
|
||||||
|
invProducerIndexMap.compose(consumerIndexingMap);
|
||||||
|
|
||||||
|
LLVM_DEBUG({
|
||||||
|
llvm::dbgs() << "\t consumerLoopToProducerLoop : ";
|
||||||
|
consumerLoopToProducerLoop.print(llvm::dbgs());
|
||||||
|
});
|
||||||
|
|
||||||
|
std::set<unsigned> candidates;
|
||||||
|
for (AffineExpr expr : consumerLoopToProducerLoop.getResults()) {
|
||||||
|
AffineDimExpr dimExpr = expr.dyn_cast<AffineDimExpr>();
|
||||||
|
if (!dimExpr)
|
||||||
|
continue;
|
||||||
|
unsigned position = dimExpr.getPosition();
|
||||||
|
if (fusableLoops.count(position))
|
||||||
|
candidates.insert(position);
|
||||||
|
}
|
||||||
|
LLVM_DEBUG({
|
||||||
|
llvm::dbgs() << "\t candidates :";
|
||||||
|
for (unsigned i : candidates)
|
||||||
|
llvm::dbgs() << " " << i;
|
||||||
|
llvm::dbgs() << "\n";
|
||||||
|
});
|
||||||
|
if (candidates.empty())
|
||||||
|
return {};
|
||||||
|
std::swap(candidates, fusableLoops);
|
||||||
}
|
}
|
||||||
|
|
||||||
// 2. Of the outer parallel loops, only those loops can be tiled + fused as
|
return fusableLoops;
|
||||||
// computed above for all the fused dependences can be used to tile and fuse.
|
|
||||||
DenseSet<unsigned> tilableParallelLoops;
|
|
||||||
for (auto index : llvm::seq<unsigned>(0, numOuterParallelLoops)) {
|
|
||||||
if (llvm::all_of(commonTilableLoops,
|
|
||||||
[&](const DenseSet<unsigned> &tilableLoops) {
|
|
||||||
return tilableLoops.count(index);
|
|
||||||
}))
|
|
||||||
tilableParallelLoops.insert(index);
|
|
||||||
}
|
|
||||||
return tilableParallelLoops;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Find all dependences that are to be fusable.
|
/// Find all dependences that are to be fusable.
|
||||||
static Optional<
|
static FusableOpDependencesTy
|
||||||
SmallVector<LinalgDependenceGraph::LinalgDependenceGraphElem, 1>>
|
|
||||||
findAllFusableDependences(LinalgOp op,
|
findAllFusableDependences(LinalgOp op,
|
||||||
const LinalgDependenceGraph &dependenceGraph,
|
const LinalgDependenceGraph &dependenceGraph,
|
||||||
const LinalgFusionOptions &fusionOptions) {
|
const LinalgFusionOptions &fusionOptions) {
|
||||||
SmallVector<LinalgDependenceGraph::LinalgDependenceGraphElem, 1>
|
FusableOpDependencesTy fusableDependences;
|
||||||
fusableDependences;
|
// TODO: Currently fusion would not be legal if the fusable dependence is to
|
||||||
for (auto operand : llvm::enumerate(op.getInputsAndOutputBuffers())) {
|
// the same producer but different indexing map in the consumer. Fix this, but
|
||||||
if (fusionOptions.indicesToFuse &&
|
// in the meanwhile disallow such a fusion.
|
||||||
!fusionOptions.indicesToFuse->count(operand.index()))
|
DenseMap<Operation *, AffineMap> fusedProducerIndexingMap;
|
||||||
continue;
|
for (auto operandIndex : fusionOptions.indicesToFuse) {
|
||||||
Optional<LinalgDependenceGraph::LinalgDependenceGraphElem>
|
auto fusableDependence =
|
||||||
fusableDependence =
|
findFusableProducer(op, operandIndex, dependenceGraph);
|
||||||
findFusableProducer(op, operand.index(), dependenceGraph);
|
|
||||||
if (!fusableDependence)
|
if (!fusableDependence)
|
||||||
continue;
|
return FusableOpDependencesTy{};
|
||||||
|
LinalgOp producerOp = cast<LinalgOp>(fusableDependence->dependentOpView.op);
|
||||||
|
// Do not fuse dependences that are to operations not in the same basic
|
||||||
|
// block. This avoid moving fused operations across loops that might
|
||||||
|
// themselves carry dependency making the fusion illegal.
|
||||||
|
if (producerOp.getOperation()->getBlock() !=
|
||||||
|
op.getOperation()->getBlock()) {
|
||||||
|
op.emitRemark("unhandled fusion of ops in different basic blocks");
|
||||||
|
return FusableOpDependencesTy{};
|
||||||
|
}
|
||||||
// Make sure that the indexing map of the view used for fusion in the
|
// Make sure that the indexing map of the view used for fusion in the
|
||||||
// producer is a projected permutation.
|
// producer is a projected permutation.
|
||||||
LinalgOp producerOp = cast<LinalgOp>(fusableDependence->dependentOpView.op);
|
unsigned producerIdx = fusableDependence->dependentOpView.operandIndex;
|
||||||
Value producerView = fusableDependence->dependentOpView.view;
|
AffineMap producerMap = producerOp.getIndexingMap(producerIdx);
|
||||||
unsigned producerIdx =
|
|
||||||
producerOp.getIndexOfOutputBuffer(producerView).getValue();
|
|
||||||
AffineMap producerMap = producerOp.getOutputIndexingMap(producerIdx);
|
|
||||||
if (!producerMap.isProjectedPermutation()) {
|
if (!producerMap.isProjectedPermutation()) {
|
||||||
op.emitError("unhandled non permutation indexing map for fused view in "
|
op.emitRemark("unhandled non permutation indexing map for fused view in "
|
||||||
"producer for operand at index ")
|
"producer for operand at index ")
|
||||||
<< operand.index();
|
<< operandIndex;
|
||||||
return llvm::None;
|
return FusableOpDependencesTy{};
|
||||||
}
|
}
|
||||||
Value consumerView = fusableDependence->indexingView;
|
|
||||||
unsigned consumerIdx = op.getIndexOfShapedOperand(consumerView).getValue();
|
unsigned consumerIdx = fusableDependence->indexingOpView.operandIndex;
|
||||||
if (!op.getIndexingMap(consumerIdx).isProjectedPermutation()) {
|
AffineMap consumerMap = op.getIndexingMap(consumerIdx);
|
||||||
op.emitError(
|
if (!consumerMap.isProjectedPermutation()) {
|
||||||
|
op.emitRemark(
|
||||||
"unhandled case where indexing map for fused view in the consumer is "
|
"unhandled case where indexing map for fused view in the consumer is "
|
||||||
"not a projected permuration while fusing at index ")
|
"not a projected permutation while fusing at index ")
|
||||||
<< operand.index();
|
<< operandIndex;
|
||||||
return llvm::None;
|
return FusableOpDependencesTy{};
|
||||||
}
|
}
|
||||||
fusableDependences.push_back(*fusableDependence);
|
|
||||||
if (!fusionOptions.indicesToFuse)
|
// Check if the producer is already a fusion candidate. Cannot fuse this
|
||||||
break;
|
// dependence if it has a different indexing map when used in the consumer.
|
||||||
|
if (fusedProducerIndexingMap.count(producerOp.getOperation()) &&
|
||||||
|
fusedProducerIndexingMap[producerOp.getOperation()] != consumerMap) {
|
||||||
|
op.emitRemark("unhandled fusion to the same producer but with different "
|
||||||
|
"indexing maps");
|
||||||
|
return FusableOpDependencesTy{};
|
||||||
|
}
|
||||||
|
fusedProducerIndexingMap[producerOp.getOperation()] = consumerMap;
|
||||||
|
|
||||||
|
fusableDependences[producerOp.getOperation()].push_back(*fusableDependence);
|
||||||
}
|
}
|
||||||
return fusableDependences;
|
return fusableDependences;
|
||||||
}
|
}
|
||||||
|
@ -682,13 +766,10 @@ tileAndFuseLinalgOpsImpl(PatternRewriter &rewriter, LinalgOp op,
|
||||||
ScopedContext scope(rewriter, op.getLoc());
|
ScopedContext scope(rewriter, op.getLoc());
|
||||||
|
|
||||||
// Find all the producers.
|
// Find all the producers.
|
||||||
Optional<SmallVector<LinalgDependenceGraph::LinalgDependenceGraphElem, 1>>
|
FusableOpDependencesTy fusableDependences =
|
||||||
fusableDependencesOpt =
|
|
||||||
findAllFusableDependences(op, dependenceGraph, fusionOptions);
|
findAllFusableDependences(op, dependenceGraph, fusionOptions);
|
||||||
if (!fusableDependencesOpt)
|
if (fusableDependences.empty())
|
||||||
return llvm::None;
|
return llvm::None;
|
||||||
ArrayRef<LinalgDependenceGraph::LinalgDependenceGraphElem> fusableDependences(
|
|
||||||
*fusableDependencesOpt);
|
|
||||||
|
|
||||||
// Enforce the convention that "tiling by zero" skips tiling a particular
|
// Enforce the convention that "tiling by zero" skips tiling a particular
|
||||||
// dimension. This convention is significantly simpler to handle instead of
|
// dimension. This convention is significantly simpler to handle instead of
|
||||||
|
@ -704,12 +785,12 @@ tileAndFuseLinalgOpsImpl(PatternRewriter &rewriter, LinalgOp op,
|
||||||
TiledAndFusedLinalgOps ret;
|
TiledAndFusedLinalgOps ret;
|
||||||
|
|
||||||
// Find the loops that can be tiled and fused.
|
// Find the loops that can be tiled and fused.
|
||||||
DenseSet<unsigned> tileFuseLoops =
|
std::set<unsigned> tileFuseLoops =
|
||||||
collectTileAndFuseLoops(op, fusableDependences);
|
collectTileAndFuseLoops(op, fusableDependences);
|
||||||
|
|
||||||
// If there are no fusable dependences or there are no tile+fusable loops,
|
// If there are no fusable dependences or there are no tile+fusable loops,
|
||||||
// just return.
|
// just return.
|
||||||
if (fusableDependences.empty() || tileFuseLoops.empty()) {
|
if (tileFuseLoops.empty()) {
|
||||||
return llvm::None;
|
return llvm::None;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -752,15 +833,15 @@ tileAndFuseLinalgOpsImpl(PatternRewriter &rewriter, LinalgOp op,
|
||||||
|
|
||||||
rewriter.setInsertionPoint(ret.op);
|
rewriter.setInsertionPoint(ret.op);
|
||||||
// Fuse the operands.
|
// Fuse the operands.
|
||||||
for (auto producer : enumerate(fusableDependences)) {
|
for (auto dependence : fusableDependences) {
|
||||||
LinalgOp producerOp = cast<LinalgOp>(producer.value().dependentOpView.op);
|
LinalgOp producerOp = cast<LinalgOp>(dependence.first);
|
||||||
unsigned producerIdx =
|
unsigned producerIdx =
|
||||||
producerOp.getIndexOfOutputBuffer(producer.value().dependentOpView.view)
|
dependence.second.front().dependentOpView.operandIndex;
|
||||||
.getValue();
|
|
||||||
unsigned consumerIdx =
|
unsigned consumerIdx =
|
||||||
op.getIndexOfShapedOperand(producer.value().indexingView).getValue();
|
dependence.second.front().indexingOpView.operandIndex;
|
||||||
LinalgOp fusedOp =
|
LinalgOp fusedOp = fuse(rewriter, producerOp,
|
||||||
fuse(rewriter, producerOp, producerIdx, ret.op, consumerIdx);
|
producerOp.getOutputIndex(producerIdx).getValue(),
|
||||||
|
ret.op, consumerIdx);
|
||||||
ret.fusedProducers.push_back(fusedOp);
|
ret.fusedProducers.push_back(fusedOp);
|
||||||
ret.originalProducers.push_back(producerOp);
|
ret.originalProducers.push_back(producerOp);
|
||||||
}
|
}
|
||||||
|
|
|
@ -12,6 +12,7 @@
|
||||||
#include "mlir/IR/StandardTypes.h"
|
#include "mlir/IR/StandardTypes.h"
|
||||||
#include "mlir/Support/LogicalResult.h"
|
#include "mlir/Support/LogicalResult.h"
|
||||||
#include "mlir/Support/MathExtras.h"
|
#include "mlir/Support/MathExtras.h"
|
||||||
|
#include "llvm/ADT/SmallSet.h"
|
||||||
#include "llvm/ADT/StringRef.h"
|
#include "llvm/ADT/StringRef.h"
|
||||||
#include "llvm/Support/raw_ostream.h"
|
#include "llvm/Support/raw_ostream.h"
|
||||||
|
|
||||||
|
@ -450,6 +451,22 @@ AffineMap mlir::concatAffineMaps(ArrayRef<AffineMap> maps) {
|
||||||
maps.front().getContext());
|
maps.front().getContext());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
AffineMap mlir::getProjectedMap(AffineMap map,
|
||||||
|
ArrayRef<unsigned> projectedDimensions) {
|
||||||
|
DenseSet<unsigned> projectedDims(projectedDimensions.begin(),
|
||||||
|
projectedDimensions.end());
|
||||||
|
MLIRContext *context = map.getContext();
|
||||||
|
SmallVector<AffineExpr, 4> resultExprs;
|
||||||
|
for (auto dim : enumerate(llvm::seq<unsigned>(0, map.getNumDims()))) {
|
||||||
|
if (!projectedDims.count(dim.value()))
|
||||||
|
resultExprs.push_back(getAffineDimExpr(dim.index(), context));
|
||||||
|
else
|
||||||
|
resultExprs.push_back(getAffineConstantExpr(0, context));
|
||||||
|
}
|
||||||
|
return map.compose(AffineMap::get(
|
||||||
|
map.getNumDims() - projectedDimensions.size(), 0, resultExprs, context));
|
||||||
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// MutableAffineMap.
|
// MutableAffineMap.
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
// RUN: mlir-opt %s -test-linalg-fusion-transform-patterns -canonicalize -cse -split-input-file | FileCheck %s
|
// RUN: mlir-opt %s -test-linalg-fusion-transform-patterns -canonicalize -cse -split-input-file -verify-diagnostics | FileCheck %s
|
||||||
|
|
||||||
module {
|
module {
|
||||||
func @basic_fusion(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>,
|
func @basic_fusion(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>,
|
||||||
|
@ -295,3 +295,121 @@ module {
|
||||||
// CHECK: }
|
// CHECK: }
|
||||||
// CHECK: linalg.matmul
|
// CHECK: linalg.matmul
|
||||||
// CHECK-SAME: __internal_linalg_transform__ = "after_lhs_fusion_original"
|
// CHECK-SAME: __internal_linalg_transform__ = "after_lhs_fusion_original"
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
module {
|
||||||
|
func @matmul_plus_matmul(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>,
|
||||||
|
%arg2: memref<?x?xf32>) {
|
||||||
|
%c0 = constant 0 : index
|
||||||
|
%c1 = constant 1 : index
|
||||||
|
%0 = dim %arg2, %c0 : memref<?x?xf32>
|
||||||
|
%1 = dim %arg2, %c1 : memref<?x?xf32>
|
||||||
|
%2 = alloc(%0, %1) : memref<?x?xf32>
|
||||||
|
linalg.matmul ins(%arg0, %arg1 : memref<?x?xf32>, memref<?x?xf32>)
|
||||||
|
outs(%2 : memref<?x?xf32>)
|
||||||
|
linalg.generic
|
||||||
|
{indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
|
||||||
|
affine_map<(d0, d1) -> (d0, d1)>,
|
||||||
|
affine_map<(d0, d1) -> (d0, d1)>],
|
||||||
|
iterator_types = ["parallel", "parallel"],
|
||||||
|
__internal_linalg_transform__ = "transpose_fusion"}
|
||||||
|
ins(%2, %2 : memref<?x?xf32>, memref<?x?xf32>)
|
||||||
|
outs(%arg2 : memref<?x?xf32>) {
|
||||||
|
^bb0(%arg3 : f32, %arg4 : f32, %arg5 : f32) :
|
||||||
|
%3 = addf %arg3, %arg4 : f32
|
||||||
|
linalg.yield %3 : f32
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// CHECK: func @matmul_plus_matmul
|
||||||
|
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: memref<?x?xf32>
|
||||||
|
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: memref<?x?xf32>
|
||||||
|
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: memref<?x?xf32>
|
||||||
|
// CHECK: %[[T2:.+]] = alloc(%{{.*}}, %{{.*}}) : memref<?x?xf32>
|
||||||
|
// CHECK: linalg.matmul
|
||||||
|
// CHECK-SAME: after_transpose_fusion_original
|
||||||
|
// CHECK: scf.parallel (%[[ARG3:[a-zA-Z0-9_]+]], %[[ARG4:.[a-zA-Z0-9_]+]])
|
||||||
|
// CHECK: %[[T5:.+]] = subview %[[T2]][%[[ARG3]], %[[ARG4]]]
|
||||||
|
// CHECK: %[[T6:.+]] = subview %[[ARG2]][%[[ARG3]], %[[ARG4]]]
|
||||||
|
// CHECK: %[[T8:.+]] = subview %[[ARG0]][%[[ARG3]], 0]
|
||||||
|
// CHECK: %[[T9:.+]] = subview %[[ARG1]][0, %[[ARG4]]]
|
||||||
|
// CHECK: linalg.matmul
|
||||||
|
// CHECK-SAME: after_transpose_fusion_producer
|
||||||
|
// CHECK-SAME: ins(%[[T8]], %[[T9]]
|
||||||
|
// CHECK-SAME: outs(%[[T5]]
|
||||||
|
// CHECK-NOT: linalg.matmul
|
||||||
|
// CHECK: linalg.generic
|
||||||
|
// CHECK-SAME: ins(%[[T5]], %[[T5]]
|
||||||
|
// CHECK-SAME: outs(%[[T6]]
|
||||||
|
// CHECK-SAME: after_transpose_fusion
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
module {
|
||||||
|
func @matmul_plus_transpose_matmul(%arg0: memref<?x?xf32>,
|
||||||
|
%arg1: memref<?x?xf32>,
|
||||||
|
%arg2: memref<?x?xf32>) {
|
||||||
|
%c0 = constant 0 : index
|
||||||
|
%c1 = constant 1 : index
|
||||||
|
%0 = dim %arg2, %c0 : memref<?x?xf32>
|
||||||
|
%1 = dim %arg2, %c1 : memref<?x?xf32>
|
||||||
|
%2 = alloc(%0, %1) : memref<?x?xf32>
|
||||||
|
linalg.matmul ins(%arg0, %arg1 : memref<?x?xf32>, memref<?x?xf32>)
|
||||||
|
outs(%2 : memref<?x?xf32>)
|
||||||
|
// expected-remark @+1 {{unhandled fusion to the same producer but with different indexing maps}}
|
||||||
|
linalg.generic
|
||||||
|
{indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
|
||||||
|
affine_map<(d0, d1) -> (d1, d0)>,
|
||||||
|
affine_map<(d0, d1) -> (d0, d1)>],
|
||||||
|
iterator_types = ["parallel", "parallel"],
|
||||||
|
__internal_linalg_transform__ = "transpose_fusion"}
|
||||||
|
ins(%2, %2 : memref<?x?xf32>, memref<?x?xf32>)
|
||||||
|
outs(%arg2 : memref<?x?xf32>) {
|
||||||
|
^bb0(%arg3 : f32, %arg4 : f32, %arg5 : f32) :
|
||||||
|
%3 = addf %arg3, %arg4 : f32
|
||||||
|
linalg.yield %3 : f32
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
#map0 = affine_map<(d0)[s0] -> (32, -d0 + s0)>
|
||||||
|
#map1 = affine_map<(d0)[s0] -> (64, -d0 + s0)>
|
||||||
|
#map2 = affine_map<(d0)[s0] -> (16, -d0 + s0)>
|
||||||
|
#map3 = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>
|
||||||
|
module {
|
||||||
|
func @basic_no_fusion(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>,
|
||||||
|
%arg2: memref<?x?xf32>) {
|
||||||
|
%c0 = constant 0 : index
|
||||||
|
%c1 = constant 1 : index
|
||||||
|
%c2 = constant 2 : index
|
||||||
|
%c32 = constant 32 : index
|
||||||
|
%c64 = constant 64 : index
|
||||||
|
%c16 = constant 16 : index
|
||||||
|
%cst = constant 0.000000e+00 : f32
|
||||||
|
linalg.fill(%arg2, %cst) : memref<?x?xf32>, f32
|
||||||
|
%0 = dim %arg0, %c0 : memref<?x?xf32>
|
||||||
|
%1 = dim %arg1, %c1 : memref<?x?xf32>
|
||||||
|
%2 = dim %arg0, %c1 : memref<?x?xf32>
|
||||||
|
scf.parallel (%arg3, %arg4) = (%c0, %c0) to (%0, %1) step (%c32, %c64) {
|
||||||
|
scf.for %arg5 = %c0 to %2 step %c16 {
|
||||||
|
%3 = affine.min #map0(%arg3)[%0]
|
||||||
|
%4 = affine.min #map1(%arg4)[%1]
|
||||||
|
%5 = affine.min #map2(%arg5)[%2]
|
||||||
|
%6 = subview %arg0[%arg3, %arg5] [%3, %5] [1, 1] : memref<?x?xf32> to memref<?x?xf32, #map3>
|
||||||
|
%7 = subview %arg1[%arg5, %arg4] [%5, %4] [1, 1] : memref<?x?xf32> to memref<?x?xf32, #map3>
|
||||||
|
%8 = subview %arg2[%arg3, %arg4] [%3, %4] [1, 1] : memref<?x?xf32> to memref<?x?xf32, #map3>
|
||||||
|
// expected-remark @+1 {{unhandled fusion of ops in different basic blocks}}
|
||||||
|
linalg.matmul {__internal_linalg_transform__ = "basic_fusion"}
|
||||||
|
ins(%6, %7 : memref<?x?xf32, #map3>, memref<?x?xf32, #map3>)
|
||||||
|
outs(%8 : memref<?x?xf32, #map3>)
|
||||||
|
}
|
||||||
|
scf.yield
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -43,7 +43,7 @@ static void fillFusionPatterns(MLIRContext *context,
|
||||||
LinalgTilingOptions()
|
LinalgTilingOptions()
|
||||||
.setTileSizes({32, 64, 16})
|
.setTileSizes({32, 64, 16})
|
||||||
.setLoopType(LinalgTilingLoopType::ParallelLoops),
|
.setLoopType(LinalgTilingLoopType::ParallelLoops),
|
||||||
LinalgFusionOptions(),
|
LinalgFusionOptions().setIndicesToFuse({2}),
|
||||||
LinalgMarker(Identifier::get("basic_fusion", context),
|
LinalgMarker(Identifier::get("basic_fusion", context),
|
||||||
Identifier::get("after_basic_fusion", context)),
|
Identifier::get("after_basic_fusion", context)),
|
||||||
LinalgMarker(ArrayRef<Identifier>(),
|
LinalgMarker(ArrayRef<Identifier>(),
|
||||||
|
@ -91,6 +91,19 @@ static void fillFusionPatterns(MLIRContext *context,
|
||||||
LinalgMarker(
|
LinalgMarker(
|
||||||
ArrayRef<Identifier>(),
|
ArrayRef<Identifier>(),
|
||||||
Identifier::get("after_two_operand_fusion_original", context)));
|
Identifier::get("after_two_operand_fusion_original", context)));
|
||||||
|
|
||||||
|
patterns.insert<LinalgTileAndFusePattern<GenericOp>>(
|
||||||
|
context, dependenceGraph,
|
||||||
|
LinalgTilingOptions().setTileSizes({32, 64}).setLoopType(
|
||||||
|
LinalgTilingLoopType::ParallelLoops),
|
||||||
|
LinalgFusionOptions().setIndicesToFuse({0, 1}),
|
||||||
|
LinalgMarker(Identifier::get("transpose_fusion", context),
|
||||||
|
Identifier::get("after_transpose_fusion", context)),
|
||||||
|
LinalgMarker(ArrayRef<Identifier>(),
|
||||||
|
Identifier::get("after_transpose_fusion_producer", context)),
|
||||||
|
LinalgMarker(
|
||||||
|
ArrayRef<Identifier>(),
|
||||||
|
Identifier::get("after_transpose_fusion_original", context)));
|
||||||
}
|
}
|
||||||
|
|
||||||
static void applyFusionPatterns(MLIRContext *context, FuncOp funcOp) {
|
static void applyFusionPatterns(MLIRContext *context, FuncOp funcOp) {
|
||||||
|
|
Loading…
Reference in New Issue