LinalgDependenceGraph: add const modifiers to accessors

MLIR const-correctness policy is to avoid having `const` on IR objects.
LinalgDependenceGraph is not an IR object but an auxiliary data structure.
Furthermore, it is not updated once constructed unlike IR objects. Add const
qualifiers to get* and find* methods of LinalgDependenceGraph since they are
not modifying the graph. This allows transformation functions that require the
dependence graph to take it by const-reference, clearly indicating that they
are not modifying it (and that the graph may have to be recomputed after the
transformation).

PiperOrigin-RevId: 277731608
This commit is contained in:
Alex Zinenko 2019-10-31 08:58:34 -07:00 committed by A. Unique TensorFlower
parent e55bd90bc7
commit f9a4d3bdb0
4 changed files with 39 additions and 37 deletions

View File

@ -66,7 +66,7 @@ public:
};
using LinalgDependences = llvm::SmallVector<LinalgDependenceGraphElem, 8>;
using DependenceGraph = DenseMap<Operation *, LinalgDependences>;
using dependence_iterator = LinalgDependences::iterator;
using dependence_iterator = LinalgDependences::const_iterator;
using dependence_range = llvm::iterator_range<dependence_iterator>;
enum DependenceType { RAR = 0, RAW, WAR, WAW, NumTypes };
@ -74,31 +74,33 @@ public:
LinalgDependenceGraph(Aliases &aliases, ArrayRef<Operation *> ops);
/// Returns the X such that op -> X is a dependence of type dt.
dependence_range getDependencesFrom(Operation *src, DependenceType dt);
dependence_range getDependencesFrom(LinalgOp src, DependenceType dt);
dependence_range getDependencesFrom(Operation *src, DependenceType dt) const;
dependence_range getDependencesFrom(LinalgOp src, DependenceType dt) const;
/// Returns the X such that X -> op is a dependence of type dt.
dependence_range getDependencesInto(Operation *dst, DependenceType dt);
dependence_range getDependencesInto(LinalgOp dst, DependenceType dt);
dependence_range getDependencesInto(Operation *dst, DependenceType dt) const;
dependence_range getDependencesInto(LinalgOp dst, DependenceType dt) const;
/// Returns the operations that are interleaved between `srcLinalgOp` and
/// `dstLinalgOp` and that are involved in any RAW, WAR or WAW dependence
/// relation with `srcLinalgOp`, on any view.
/// Any such operation prevents reordering.
SmallVector<Operation *, 8> findCoveringDependences(LinalgOp srcLinalgOp,
LinalgOp dstLinalgOp);
SmallVector<Operation *, 8>
findCoveringDependences(LinalgOp srcLinalgOp, LinalgOp dstLinalgOp) const;
/// Returns the operations that are interleaved between `srcLinalgOp` and
/// `dstLinalgOp` and that are involved in a RAR or RAW with `srcLinalgOp`.
/// Dependences are restricted to views aliasing `view`.
SmallVector<Operation *, 8>
findCoveringReads(LinalgOp srcLinalgOp, LinalgOp dstLinalgOp, Value *view);
SmallVector<Operation *, 8> findCoveringReads(LinalgOp srcLinalgOp,
LinalgOp dstLinalgOp,
Value *view) const;
/// Returns the operations that are interleaved between `srcLinalgOp` and
/// `dstLinalgOp` and that are involved in a WAR or WAW with `srcLinalgOp`.
/// Dependences are restricted to views aliasing `view`.
SmallVector<Operation *, 8>
findCoveringWrites(LinalgOp srcLinalgOp, LinalgOp dstLinalgOp, Value *view);
SmallVector<Operation *, 8> findCoveringWrites(LinalgOp srcLinalgOp,
LinalgOp dstLinalgOp,
Value *view) const;
private:
// Keep dependences in both directions, this is not just a performance gain
@ -125,7 +127,7 @@ private:
SmallVector<Operation *, 8>
findOperationsWithCoveringDependences(LinalgOp srcLinalgOp,
LinalgOp dstLinalgOp, Value *view,
ArrayRef<DependenceType> types);
ArrayRef<DependenceType> types) const;
Aliases &aliases;
SmallVector<Operation *, 8> linalgOps;

View File

@ -88,7 +88,7 @@ struct FusionInfo {
/// method is called.
Optional<FusionInfo> fuseProducerOf(OpBuilder &b, LinalgOp consumer,
unsigned consumerIdx,
LinalgDependenceGraph &graph,
const LinalgDependenceGraph &graph,
OperationFolder *folder = nullptr);
/// Returns the linearized list of all view dimensions in a linalgOp. Applying

View File

@ -113,28 +113,32 @@ void LinalgDependenceGraph::addDependenceElem(DependenceType dt,
LinalgDependenceGraph::dependence_range
LinalgDependenceGraph::getDependencesFrom(
LinalgOp src, LinalgDependenceGraph::DependenceType dt) {
LinalgOp src, LinalgDependenceGraph::DependenceType dt) const {
return getDependencesFrom(src.getOperation(), dt);
}
LinalgDependenceGraph::dependence_range
LinalgDependenceGraph::getDependencesFrom(
Operation *src, LinalgDependenceGraph::DependenceType dt) {
auto &vec = dependencesFromGraphs[dt][src];
return llvm::make_range(vec.begin(), vec.end());
Operation *src, LinalgDependenceGraph::DependenceType dt) const {
auto iter = dependencesFromGraphs[dt].find(src);
if (iter == dependencesFromGraphs[dt].end())
return llvm::make_range(nullptr, nullptr);
return llvm::make_range(iter->second.begin(), iter->second.end());
}
LinalgDependenceGraph::dependence_range
LinalgDependenceGraph::getDependencesInto(
LinalgOp dst, LinalgDependenceGraph::DependenceType dt) {
LinalgOp dst, LinalgDependenceGraph::DependenceType dt) const {
return getDependencesInto(dst.getOperation(), dt);
}
LinalgDependenceGraph::dependence_range
LinalgDependenceGraph::getDependencesInto(
Operation *dst, LinalgDependenceGraph::DependenceType dt) {
auto &vec = dependencesIntoGraphs[dt][dst];
return llvm::make_range(vec.begin(), vec.end());
Operation *dst, LinalgDependenceGraph::DependenceType dt) const {
auto iter = dependencesIntoGraphs[dt].find(dst);
if (iter == dependencesIntoGraphs[dt].end())
return llvm::make_range(nullptr, nullptr);
return llvm::make_range(iter->second.begin(), iter->second.end());
}
void LinalgDependenceGraph::addDependencesBetween(LinalgOp src, LinalgOp dst) {
@ -178,23 +182,21 @@ void LinalgDependenceGraph::addDependencesBetween(LinalgOp src, LinalgOp dst) {
SmallVector<Operation *, 8>
LinalgDependenceGraph::findCoveringDependences(LinalgOp srcLinalgOp,
LinalgOp dstLinalgOp) {
LinalgOp dstLinalgOp) const {
return findOperationsWithCoveringDependences(
srcLinalgOp, dstLinalgOp, nullptr,
{DependenceType::WAW, DependenceType::WAR, DependenceType::RAW});
}
SmallVector<Operation *, 8>
LinalgDependenceGraph::findCoveringWrites(LinalgOp srcLinalgOp,
LinalgOp dstLinalgOp, Value *view) {
SmallVector<Operation *, 8> LinalgDependenceGraph::findCoveringWrites(
LinalgOp srcLinalgOp, LinalgOp dstLinalgOp, Value *view) const {
return findOperationsWithCoveringDependences(
srcLinalgOp, dstLinalgOp, view,
{DependenceType::WAW, DependenceType::WAR});
}
SmallVector<Operation *, 8>
LinalgDependenceGraph::findCoveringReads(LinalgOp srcLinalgOp,
LinalgOp dstLinalgOp, Value *view) {
SmallVector<Operation *, 8> LinalgDependenceGraph::findCoveringReads(
LinalgOp srcLinalgOp, LinalgOp dstLinalgOp, Value *view) const {
return findOperationsWithCoveringDependences(
srcLinalgOp, dstLinalgOp, view,
{DependenceType::RAR, DependenceType::RAW});
@ -203,11 +205,11 @@ LinalgDependenceGraph::findCoveringReads(LinalgOp srcLinalgOp,
SmallVector<Operation *, 8>
LinalgDependenceGraph::findOperationsWithCoveringDependences(
LinalgOp srcLinalgOp, LinalgOp dstLinalgOp, Value *view,
ArrayRef<DependenceType> types) {
ArrayRef<DependenceType> types) const {
auto *src = srcLinalgOp.getOperation();
auto *dst = dstLinalgOp.getOperation();
auto srcPos = linalgOpPositions[src];
auto dstPos = linalgOpPositions[dst];
auto srcPos = linalgOpPositions.lookup(src);
auto dstPos = linalgOpPositions.lookup(dst);
assert(srcPos < dstPos && "expected dst after src in IR traversal order");
SmallVector<Operation *, 8> res;
@ -216,7 +218,7 @@ LinalgDependenceGraph::findOperationsWithCoveringDependences(
// TODO(ntv) we are not considering paths yet, just interleaved positions.
for (auto dt : types) {
for (auto dependence : getDependencesFrom(src, dt)) {
auto interimPos = linalgOpPositions[dependence.dependentOpView.op];
auto interimPos = linalgOpPositions.lookup(dependence.dependentOpView.op);
// Skip if not interleaved.
if (interimPos >= dstPos || interimPos <= srcPos)
continue;

View File

@ -227,11 +227,9 @@ static bool isStructurallyFusableProducer(LinalgOp producer, Value *readView,
}
// Only consider RAW atm.
Optional<FusionInfo> mlir::linalg::fuseProducerOf(OpBuilder &b,
LinalgOp consumer,
unsigned consumerIdx,
LinalgDependenceGraph &graph,
OperationFolder *folder) {
Optional<FusionInfo> mlir::linalg::fuseProducerOf(
OpBuilder &b, LinalgOp consumer, unsigned consumerIdx,
const LinalgDependenceGraph &graph, OperationFolder *folder) {
LLVM_DEBUG(dbgs() << "\nStart examining consumer: "
<< *consumer.getOperation());
for (auto dependence : graph.getDependencesInto(