[mlir][Linalg] Change LinalgDependenceGraph to use LinalgOp.

Using LinalgOp will reduce the repeated conversion from Operation <->
LinalgOp.

Differential Revision: https://reviews.llvm.org/D91101
This commit is contained in:
MaheshRavishankar 2020-11-13 12:21:43 -08:00
parent a8db144169
commit bf3861bf71
3 changed files with 12 additions and 12 deletions

View File

@ -9,6 +9,7 @@
#ifndef MLIR_DIALECT_LINALG_ANALYSIS_DEPENDENCEANALYSIS_H_ #ifndef MLIR_DIALECT_LINALG_ANALYSIS_DEPENDENCEANALYSIS_H_
#define MLIR_DIALECT_LINALG_ANALYSIS_DEPENDENCEANALYSIS_H_ #define MLIR_DIALECT_LINALG_ANALYSIS_DEPENDENCEANALYSIS_H_
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/IR/Builders.h" #include "mlir/IR/Builders.h"
#include "mlir/IR/OpDefinition.h" #include "mlir/IR/OpDefinition.h"
@ -67,7 +68,7 @@ public:
// Builds a linalg dependence graph for the ops of type LinalgOp under `f`. // Builds a linalg dependence graph for the ops of type LinalgOp under `f`.
static LinalgDependenceGraph buildDependenceGraph(Aliases &aliases, FuncOp f); static LinalgDependenceGraph buildDependenceGraph(Aliases &aliases, FuncOp f);
LinalgDependenceGraph(Aliases &aliases, ArrayRef<Operation *> ops); LinalgDependenceGraph(Aliases &aliases, ArrayRef<LinalgOp> ops);
/// Returns the X such that op -> X is a dependence of type dt. /// Returns the X such that op -> X is a dependence of type dt.
dependence_range getDependencesFrom(Operation *src, DependenceType dt) const; dependence_range getDependencesFrom(Operation *src, DependenceType dt) const;
@ -168,7 +169,7 @@ private:
ArrayRef<DependenceType> types) const; ArrayRef<DependenceType> types) const;
Aliases &aliases; Aliases &aliases;
SmallVector<Operation *, 8> linalgOps; SmallVector<LinalgOp, 8> linalgOps;
DenseMap<Operation *, unsigned> linalgOpPositions; DenseMap<Operation *, unsigned> linalgOpPositions;
}; };
} // namespace linalg } // namespace linalg

View File

@ -86,21 +86,21 @@ StringRef LinalgDependenceGraph::getDependenceTypeStr(DependenceType depType) {
LinalgDependenceGraph LinalgDependenceGraph
LinalgDependenceGraph::buildDependenceGraph(Aliases &aliases, FuncOp f) { LinalgDependenceGraph::buildDependenceGraph(Aliases &aliases, FuncOp f) {
SmallVector<Operation *, 8> linalgOps; SmallVector<LinalgOp, 8> linalgOps;
f.walk([&](LinalgOp op) { linalgOps.push_back(op); }); f.walk([&](LinalgOp op) { linalgOps.push_back(op); });
return LinalgDependenceGraph(aliases, linalgOps); return LinalgDependenceGraph(aliases, linalgOps);
} }
LinalgDependenceGraph::LinalgDependenceGraph(Aliases &aliases, LinalgDependenceGraph::LinalgDependenceGraph(Aliases &aliases,
ArrayRef<Operation *> ops) ArrayRef<LinalgOp> ops)
: aliases(aliases), linalgOps(ops.begin(), ops.end()) { : aliases(aliases), linalgOps(ops.begin(), ops.end()) {
for (auto en : llvm::enumerate(linalgOps)) { for (auto en : llvm::enumerate(linalgOps)) {
assert(isa<LinalgOp>(en.value()) && "Expected value for LinalgOp"); linalgOpPositions.insert(
linalgOpPositions.insert(std::make_pair(en.value(), en.index())); std::make_pair(en.value().getOperation(), en.index()));
} }
for (unsigned i = 0, e = ops.size(); i < e; ++i) { for (unsigned i = 0, e = ops.size(); i < e; ++i) {
for (unsigned j = i + 1; j < e; ++j) { for (unsigned j = i + 1; j < e; ++j) {
addDependencesBetween(cast<LinalgOp>(ops[i]), cast<LinalgOp>(ops[j])); addDependencesBetween(ops[i], ops[j]);
} }
} }
} }

View File

@ -124,7 +124,7 @@ static LogicalResult fuseLinalgOpsGreedily(FuncOp f) {
DenseSet<Operation *> eraseSet; DenseSet<Operation *> eraseSet;
// Save original Linalg ops, we only want to make a pass over those. // Save original Linalg ops, we only want to make a pass over those.
SmallVector<Operation *, 8> linalgOps; SmallVector<LinalgOp, 8> linalgOps;
f.walk([&](LinalgOp op) { f.walk([&](LinalgOp op) {
// TODO: support multi-results. // TODO: support multi-results.
if (op.getOperation()->getNumResults() <= 1) if (op.getOperation()->getNumResults() <= 1)
@ -133,8 +133,7 @@ static LogicalResult fuseLinalgOpsGreedily(FuncOp f) {
// Tile and Fuse for tensors inputs (TODO: all tensor operands). // Tile and Fuse for tensors inputs (TODO: all tensor operands).
bool changed = false; bool changed = false;
for (auto *op : llvm::reverse(linalgOps)) { for (LinalgOp linalgOp : llvm::reverse(linalgOps)) {
LinalgOp linalgOp = cast<LinalgOp>(op);
for (auto en : llvm::enumerate(linalgOp.getShapedOperands())) { for (auto en : llvm::enumerate(linalgOp.getShapedOperands())) {
if (en.value().getType().isa<MemRefType>()) { if (en.value().getType().isa<MemRefType>()) {
// TODO: LinalgDependenceGraph should be able to update itself. // TODO: LinalgDependenceGraph should be able to update itself.
@ -142,7 +141,7 @@ static LogicalResult fuseLinalgOpsGreedily(FuncOp f) {
// removed. // removed.
linalg::Aliases aliases; linalg::Aliases aliases;
linalg::LinalgDependenceGraph graph(aliases, linalgOps); linalg::LinalgDependenceGraph graph(aliases, linalgOps);
if (auto info = fuseProducerOfBuffer(b, op, en.index(), graph)) { if (auto info = fuseProducerOfBuffer(b, linalgOp, en.index(), graph)) {
auto *originalOp = info->originalProducer.getOperation(); auto *originalOp = info->originalProducer.getOperation();
eraseSet.insert(originalOp); eraseSet.insert(originalOp);
auto *originalOpInLinalgOpsVector = auto *originalOpInLinalgOpsVector =
@ -155,7 +154,7 @@ static LogicalResult fuseLinalgOpsGreedily(FuncOp f) {
// Tile and Fuse tensor input (TODO: init_tensors too). // Tile and Fuse tensor input (TODO: init_tensors too).
if (en.index() >= linalgOp.getNumInputs()) if (en.index() >= linalgOp.getNumInputs())
continue; continue;
if (auto info = fuseProducerOfTensor(b, op, en.index())) { if (auto info = fuseProducerOfTensor(b, linalgOp, en.index())) {
auto *originalOp = info->originalProducer.getOperation(); auto *originalOp = info->originalProducer.getOperation();
auto *originalOpInLinalgOpsVector = auto *originalOpInLinalgOpsVector =
std::find(linalgOps.begin(), linalgOps.end(), originalOp); std::find(linalgOps.begin(), linalgOps.end(), originalOp);