[mlir][Linalg] Change LinalgDependenceGraph to use LinalgOp.

Using LinalgOp will reduce the repeated conversion from Operation <->
LinalgOp.

Differential Revision: https://reviews.llvm.org/D91101
This commit is contained in:
MaheshRavishankar 2020-11-13 12:21:43 -08:00
parent a8db144169
commit bf3861bf71
3 changed files with 12 additions and 12 deletions

View File

@ -9,6 +9,7 @@
#ifndef MLIR_DIALECT_LINALG_ANALYSIS_DEPENDENCEANALYSIS_H_
#define MLIR_DIALECT_LINALG_ANALYSIS_DEPENDENCEANALYSIS_H_
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/OpDefinition.h"
@ -67,7 +68,7 @@ public:
// Builds a linalg dependence graph for the ops of type LinalgOp under `f`.
static LinalgDependenceGraph buildDependenceGraph(Aliases &aliases, FuncOp f);
LinalgDependenceGraph(Aliases &aliases, ArrayRef<Operation *> ops);
LinalgDependenceGraph(Aliases &aliases, ArrayRef<LinalgOp> ops);
/// Returns the X such that op -> X is a dependence of type dt.
dependence_range getDependencesFrom(Operation *src, DependenceType dt) const;
@ -168,7 +169,7 @@ private:
ArrayRef<DependenceType> types) const;
Aliases &aliases;
SmallVector<Operation *, 8> linalgOps;
SmallVector<LinalgOp, 8> linalgOps;
DenseMap<Operation *, unsigned> linalgOpPositions;
};
} // namespace linalg

View File

@ -86,21 +86,21 @@ StringRef LinalgDependenceGraph::getDependenceTypeStr(DependenceType depType) {
LinalgDependenceGraph
LinalgDependenceGraph::buildDependenceGraph(Aliases &aliases, FuncOp f) {
SmallVector<Operation *, 8> linalgOps;
SmallVector<LinalgOp, 8> linalgOps;
f.walk([&](LinalgOp op) { linalgOps.push_back(op); });
return LinalgDependenceGraph(aliases, linalgOps);
}
LinalgDependenceGraph::LinalgDependenceGraph(Aliases &aliases,
ArrayRef<Operation *> ops)
ArrayRef<LinalgOp> ops)
: aliases(aliases), linalgOps(ops.begin(), ops.end()) {
for (auto en : llvm::enumerate(linalgOps)) {
assert(isa<LinalgOp>(en.value()) && "Expected value for LinalgOp");
linalgOpPositions.insert(std::make_pair(en.value(), en.index()));
linalgOpPositions.insert(
std::make_pair(en.value().getOperation(), en.index()));
}
for (unsigned i = 0, e = ops.size(); i < e; ++i) {
for (unsigned j = i + 1; j < e; ++j) {
addDependencesBetween(cast<LinalgOp>(ops[i]), cast<LinalgOp>(ops[j]));
addDependencesBetween(ops[i], ops[j]);
}
}
}

View File

@ -124,7 +124,7 @@ static LogicalResult fuseLinalgOpsGreedily(FuncOp f) {
DenseSet<Operation *> eraseSet;
// Save original Linalg ops, we only want to make a pass over those.
SmallVector<Operation *, 8> linalgOps;
SmallVector<LinalgOp, 8> linalgOps;
f.walk([&](LinalgOp op) {
// TODO: support multi-results.
if (op.getOperation()->getNumResults() <= 1)
@ -133,8 +133,7 @@ static LogicalResult fuseLinalgOpsGreedily(FuncOp f) {
// Tile and Fuse for tensors inputs (TODO: all tensor operands).
bool changed = false;
for (auto *op : llvm::reverse(linalgOps)) {
LinalgOp linalgOp = cast<LinalgOp>(op);
for (LinalgOp linalgOp : llvm::reverse(linalgOps)) {
for (auto en : llvm::enumerate(linalgOp.getShapedOperands())) {
if (en.value().getType().isa<MemRefType>()) {
// TODO: LinalgDependenceGraph should be able to update itself.
@ -142,7 +141,7 @@ static LogicalResult fuseLinalgOpsGreedily(FuncOp f) {
// removed.
linalg::Aliases aliases;
linalg::LinalgDependenceGraph graph(aliases, linalgOps);
if (auto info = fuseProducerOfBuffer(b, op, en.index(), graph)) {
if (auto info = fuseProducerOfBuffer(b, linalgOp, en.index(), graph)) {
auto *originalOp = info->originalProducer.getOperation();
eraseSet.insert(originalOp);
auto *originalOpInLinalgOpsVector =
@ -155,7 +154,7 @@ static LogicalResult fuseLinalgOpsGreedily(FuncOp f) {
// Tile and Fuse tensor input (TODO: init_tensors too).
if (en.index() >= linalgOp.getNumInputs())
continue;
if (auto info = fuseProducerOfTensor(b, op, en.index())) {
if (auto info = fuseProducerOfTensor(b, linalgOp, en.index())) {
auto *originalOp = info->originalProducer.getOperation();
auto *originalOpInLinalgOpsVector =
std::find(linalgOps.begin(), linalgOps.end(), originalOp);