forked from OSchip/llvm-project
[mlir][Linalg] Change LinalgDependenceGraph to use LinalgOp.
Using LinalgOp will reduce the repeated conversion from Operation <-> LinalgOp. Differential Revision: https://reviews.llvm.org/D91101
This commit is contained in:
parent
a8db144169
commit
bf3861bf71
|
@ -9,6 +9,7 @@
|
|||
#ifndef MLIR_DIALECT_LINALG_ANALYSIS_DEPENDENCEANALYSIS_H_
|
||||
#define MLIR_DIALECT_LINALG_ANALYSIS_DEPENDENCEANALYSIS_H_
|
||||
|
||||
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/OpDefinition.h"
|
||||
|
||||
|
@ -67,7 +68,7 @@ public:
|
|||
|
||||
// Builds a linalg dependence graph for the ops of type LinalgOp under `f`.
|
||||
static LinalgDependenceGraph buildDependenceGraph(Aliases &aliases, FuncOp f);
|
||||
LinalgDependenceGraph(Aliases &aliases, ArrayRef<Operation *> ops);
|
||||
LinalgDependenceGraph(Aliases &aliases, ArrayRef<LinalgOp> ops);
|
||||
|
||||
/// Returns the X such that op -> X is a dependence of type dt.
|
||||
dependence_range getDependencesFrom(Operation *src, DependenceType dt) const;
|
||||
|
@ -168,7 +169,7 @@ private:
|
|||
ArrayRef<DependenceType> types) const;
|
||||
|
||||
Aliases &aliases;
|
||||
SmallVector<Operation *, 8> linalgOps;
|
||||
SmallVector<LinalgOp, 8> linalgOps;
|
||||
DenseMap<Operation *, unsigned> linalgOpPositions;
|
||||
};
|
||||
} // namespace linalg
|
||||
|
|
|
@ -86,21 +86,21 @@ StringRef LinalgDependenceGraph::getDependenceTypeStr(DependenceType depType) {
|
|||
|
||||
LinalgDependenceGraph
|
||||
LinalgDependenceGraph::buildDependenceGraph(Aliases &aliases, FuncOp f) {
|
||||
SmallVector<Operation *, 8> linalgOps;
|
||||
SmallVector<LinalgOp, 8> linalgOps;
|
||||
f.walk([&](LinalgOp op) { linalgOps.push_back(op); });
|
||||
return LinalgDependenceGraph(aliases, linalgOps);
|
||||
}
|
||||
|
||||
LinalgDependenceGraph::LinalgDependenceGraph(Aliases &aliases,
|
||||
ArrayRef<Operation *> ops)
|
||||
ArrayRef<LinalgOp> ops)
|
||||
: aliases(aliases), linalgOps(ops.begin(), ops.end()) {
|
||||
for (auto en : llvm::enumerate(linalgOps)) {
|
||||
assert(isa<LinalgOp>(en.value()) && "Expected value for LinalgOp");
|
||||
linalgOpPositions.insert(std::make_pair(en.value(), en.index()));
|
||||
linalgOpPositions.insert(
|
||||
std::make_pair(en.value().getOperation(), en.index()));
|
||||
}
|
||||
for (unsigned i = 0, e = ops.size(); i < e; ++i) {
|
||||
for (unsigned j = i + 1; j < e; ++j) {
|
||||
addDependencesBetween(cast<LinalgOp>(ops[i]), cast<LinalgOp>(ops[j]));
|
||||
addDependencesBetween(ops[i], ops[j]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -124,7 +124,7 @@ static LogicalResult fuseLinalgOpsGreedily(FuncOp f) {
|
|||
DenseSet<Operation *> eraseSet;
|
||||
|
||||
// Save original Linalg ops, we only want to make a pass over those.
|
||||
SmallVector<Operation *, 8> linalgOps;
|
||||
SmallVector<LinalgOp, 8> linalgOps;
|
||||
f.walk([&](LinalgOp op) {
|
||||
// TODO: support multi-results.
|
||||
if (op.getOperation()->getNumResults() <= 1)
|
||||
|
@ -133,8 +133,7 @@ static LogicalResult fuseLinalgOpsGreedily(FuncOp f) {
|
|||
|
||||
// Tile and Fuse for tensors inputs (TODO: all tensor operands).
|
||||
bool changed = false;
|
||||
for (auto *op : llvm::reverse(linalgOps)) {
|
||||
LinalgOp linalgOp = cast<LinalgOp>(op);
|
||||
for (LinalgOp linalgOp : llvm::reverse(linalgOps)) {
|
||||
for (auto en : llvm::enumerate(linalgOp.getShapedOperands())) {
|
||||
if (en.value().getType().isa<MemRefType>()) {
|
||||
// TODO: LinalgDependenceGraph should be able to update itself.
|
||||
|
@ -142,7 +141,7 @@ static LogicalResult fuseLinalgOpsGreedily(FuncOp f) {
|
|||
// removed.
|
||||
linalg::Aliases aliases;
|
||||
linalg::LinalgDependenceGraph graph(aliases, linalgOps);
|
||||
if (auto info = fuseProducerOfBuffer(b, op, en.index(), graph)) {
|
||||
if (auto info = fuseProducerOfBuffer(b, linalgOp, en.index(), graph)) {
|
||||
auto *originalOp = info->originalProducer.getOperation();
|
||||
eraseSet.insert(originalOp);
|
||||
auto *originalOpInLinalgOpsVector =
|
||||
|
@ -155,7 +154,7 @@ static LogicalResult fuseLinalgOpsGreedily(FuncOp f) {
|
|||
// Tile and Fuse tensor input (TODO: init_tensors too).
|
||||
if (en.index() >= linalgOp.getNumInputs())
|
||||
continue;
|
||||
if (auto info = fuseProducerOfTensor(b, op, en.index())) {
|
||||
if (auto info = fuseProducerOfTensor(b, linalgOp, en.index())) {
|
||||
auto *originalOp = info->originalProducer.getOperation();
|
||||
auto *originalOpInLinalgOpsVector =
|
||||
std::find(linalgOps.begin(), linalgOps.end(), originalOp);
|
||||
|
|
Loading…
Reference in New Issue