From f9a4d3bdb024a918fd5eab7d59176dbc2ab08e80 Mon Sep 17 00:00:00 2001 From: Alex Zinenko Date: Thu, 31 Oct 2019 08:58:34 -0700 Subject: [PATCH] LinalgDependenceGraph: add const modifiers to accessors MLIR const-correctness policy is to avoid having `const` on IR objects. LinalgDependenceGraph is not an IR object but an auxiliary data structure. Furthermore, it is not updated once constructed unlike IR objects. Add const qualifiers to get* and find* methods of LinalgDependenceGraph since they are not modifying the graph. This allows transformation functions that require the dependence graph to take it by const-reference, clearly indicating that they are not modifying it (and that the graph may have to be recomputed after the transformation). PiperOrigin-RevId: 277731608 --- .../Linalg/Analysis/DependenceAnalysis.h | 26 ++++++------ .../include/mlir/Dialect/Linalg/Utils/Utils.h | 2 +- .../Linalg/Analysis/DependenceAnalysis.cpp | 40 ++++++++++--------- mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp | 8 ++-- 4 files changed, 39 insertions(+), 37 deletions(-) diff --git a/mlir/include/mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h b/mlir/include/mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h index 2367363b9b41..65cb2e63dc02 100644 --- a/mlir/include/mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h +++ b/mlir/include/mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h @@ -66,7 +66,7 @@ public: }; using LinalgDependences = llvm::SmallVector; using DependenceGraph = DenseMap; - using dependence_iterator = LinalgDependences::iterator; + using dependence_iterator = LinalgDependences::const_iterator; using dependence_range = llvm::iterator_range; enum DependenceType { RAR = 0, RAW, WAR, WAW, NumTypes }; @@ -74,31 +74,33 @@ public: LinalgDependenceGraph(Aliases &aliases, ArrayRef ops); /// Returns the X such that op -> X is a dependence of type dt. - dependence_range getDependencesFrom(Operation *src, DependenceType dt); - dependence_range getDependencesFrom(LinalgOp src, DependenceType dt); + dependence_range getDependencesFrom(Operation *src, DependenceType dt) const; + dependence_range getDependencesFrom(LinalgOp src, DependenceType dt) const; /// Returns the X such that X -> op is a dependence of type dt. - dependence_range getDependencesInto(Operation *dst, DependenceType dt); - dependence_range getDependencesInto(LinalgOp dst, DependenceType dt); + dependence_range getDependencesInto(Operation *dst, DependenceType dt) const; + dependence_range getDependencesInto(LinalgOp dst, DependenceType dt) const; /// Returns the operations that are interleaved between `srcLinalgOp` and /// `dstLinalgOp` and that are involved in any RAW, WAR or WAW dependence /// relation with `srcLinalgOp`, on any view. /// Any such operation prevents reordering. - SmallVector findCoveringDependences(LinalgOp srcLinalgOp, - LinalgOp dstLinalgOp); + SmallVector + findCoveringDependences(LinalgOp srcLinalgOp, LinalgOp dstLinalgOp) const; /// Returns the operations that are interleaved between `srcLinalgOp` and /// `dstLinalgOp` and that are involved in a RAR or RAW with `srcLinalgOp`. /// Dependences are restricted to views aliasing `view`. - SmallVector - findCoveringReads(LinalgOp srcLinalgOp, LinalgOp dstLinalgOp, Value *view); + SmallVector findCoveringReads(LinalgOp srcLinalgOp, + LinalgOp dstLinalgOp, + Value *view) const; /// Returns the operations that are interleaved between `srcLinalgOp` and /// `dstLinalgOp` and that are involved in a WAR or WAW with `srcLinalgOp`. /// Dependences are restricted to views aliasing `view`. - SmallVector - findCoveringWrites(LinalgOp srcLinalgOp, LinalgOp dstLinalgOp, Value *view); + SmallVector findCoveringWrites(LinalgOp srcLinalgOp, + LinalgOp dstLinalgOp, + Value *view) const; private: // Keep dependences in both directions, this is not just a performance gain @@ -125,7 +127,7 @@ private: SmallVector findOperationsWithCoveringDependences(LinalgOp srcLinalgOp, LinalgOp dstLinalgOp, Value *view, - ArrayRef types); + ArrayRef types) const; Aliases &aliases; SmallVector linalgOps; diff --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h index 2b2fdfb5efc3..0bfcdea20077 100644 --- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h +++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h @@ -88,7 +88,7 @@ struct FusionInfo { /// method is called. Optional fuseProducerOf(OpBuilder &b, LinalgOp consumer, unsigned consumerIdx, - LinalgDependenceGraph &graph, + const LinalgDependenceGraph &graph, OperationFolder *folder = nullptr); /// Returns the linearized list of all view dimensions in a linalgOp. Applying diff --git a/mlir/lib/Dialect/Linalg/Analysis/DependenceAnalysis.cpp b/mlir/lib/Dialect/Linalg/Analysis/DependenceAnalysis.cpp index 14db309589ad..3a90e61ed10e 100644 --- a/mlir/lib/Dialect/Linalg/Analysis/DependenceAnalysis.cpp +++ b/mlir/lib/Dialect/Linalg/Analysis/DependenceAnalysis.cpp @@ -113,28 +113,32 @@ void LinalgDependenceGraph::addDependenceElem(DependenceType dt, LinalgDependenceGraph::dependence_range LinalgDependenceGraph::getDependencesFrom( - LinalgOp src, LinalgDependenceGraph::DependenceType dt) { + LinalgOp src, LinalgDependenceGraph::DependenceType dt) const { return getDependencesFrom(src.getOperation(), dt); } LinalgDependenceGraph::dependence_range LinalgDependenceGraph::getDependencesFrom( - Operation *src, LinalgDependenceGraph::DependenceType dt) { - auto &vec = dependencesFromGraphs[dt][src]; - return llvm::make_range(vec.begin(), vec.end()); + Operation *src, LinalgDependenceGraph::DependenceType dt) const { + auto iter = dependencesFromGraphs[dt].find(src); + if (iter == dependencesFromGraphs[dt].end()) + return llvm::make_range(nullptr, nullptr); + return llvm::make_range(iter->second.begin(), iter->second.end()); } LinalgDependenceGraph::dependence_range LinalgDependenceGraph::getDependencesInto( - LinalgOp dst, LinalgDependenceGraph::DependenceType dt) { + LinalgOp dst, LinalgDependenceGraph::DependenceType dt) const { return getDependencesInto(dst.getOperation(), dt); } LinalgDependenceGraph::dependence_range LinalgDependenceGraph::getDependencesInto( - Operation *dst, LinalgDependenceGraph::DependenceType dt) { - auto &vec = dependencesIntoGraphs[dt][dst]; - return llvm::make_range(vec.begin(), vec.end()); + Operation *dst, LinalgDependenceGraph::DependenceType dt) const { + auto iter = dependencesIntoGraphs[dt].find(dst); + if (iter == dependencesIntoGraphs[dt].end()) + return llvm::make_range(nullptr, nullptr); + return llvm::make_range(iter->second.begin(), iter->second.end()); } void LinalgDependenceGraph::addDependencesBetween(LinalgOp src, LinalgOp dst) { @@ -178,23 +182,21 @@ void LinalgDependenceGraph::addDependencesBetween(LinalgOp src, LinalgOp dst) { SmallVector LinalgDependenceGraph::findCoveringDependences(LinalgOp srcLinalgOp, - LinalgOp dstLinalgOp) { + LinalgOp dstLinalgOp) const { return findOperationsWithCoveringDependences( srcLinalgOp, dstLinalgOp, nullptr, {DependenceType::WAW, DependenceType::WAR, DependenceType::RAW}); } -SmallVector -LinalgDependenceGraph::findCoveringWrites(LinalgOp srcLinalgOp, - LinalgOp dstLinalgOp, Value *view) { +SmallVector LinalgDependenceGraph::findCoveringWrites( + LinalgOp srcLinalgOp, LinalgOp dstLinalgOp, Value *view) const { return findOperationsWithCoveringDependences( srcLinalgOp, dstLinalgOp, view, {DependenceType::WAW, DependenceType::WAR}); } -SmallVector -LinalgDependenceGraph::findCoveringReads(LinalgOp srcLinalgOp, - LinalgOp dstLinalgOp, Value *view) { +SmallVector LinalgDependenceGraph::findCoveringReads( + LinalgOp srcLinalgOp, LinalgOp dstLinalgOp, Value *view) const { return findOperationsWithCoveringDependences( srcLinalgOp, dstLinalgOp, view, {DependenceType::RAR, DependenceType::RAW}); @@ -203,11 +205,11 @@ LinalgDependenceGraph::findCoveringReads(LinalgOp srcLinalgOp, SmallVector LinalgDependenceGraph::findOperationsWithCoveringDependences( LinalgOp srcLinalgOp, LinalgOp dstLinalgOp, Value *view, - ArrayRef types) { + ArrayRef types) const { auto *src = srcLinalgOp.getOperation(); auto *dst = dstLinalgOp.getOperation(); - auto srcPos = linalgOpPositions[src]; - auto dstPos = linalgOpPositions[dst]; + auto srcPos = linalgOpPositions.lookup(src); + auto dstPos = linalgOpPositions.lookup(dst); assert(srcPos < dstPos && "expected dst after src in IR traversal order"); SmallVector res; @@ -216,7 +218,7 @@ LinalgDependenceGraph::findOperationsWithCoveringDependences( // TODO(ntv) we are not considering paths yet, just interleaved positions. for (auto dt : types) { for (auto dependence : getDependencesFrom(src, dt)) { - auto interimPos = linalgOpPositions[dependence.dependentOpView.op]; + auto interimPos = linalgOpPositions.lookup(dependence.dependentOpView.op); // Skip if not interleaved. if (interimPos >= dstPos || interimPos <= srcPos) continue; diff --git a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp index ebdb32b7010b..8e7370af7a3d 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp @@ -227,11 +227,9 @@ static bool isStructurallyFusableProducer(LinalgOp producer, Value *readView, } // Only consider RAW atm. -Optional mlir::linalg::fuseProducerOf(OpBuilder &b, - LinalgOp consumer, - unsigned consumerIdx, - LinalgDependenceGraph &graph, - OperationFolder *folder) { +Optional mlir::linalg::fuseProducerOf( + OpBuilder &b, LinalgOp consumer, unsigned consumerIdx, + const LinalgDependenceGraph &graph, OperationFolder *folder) { LLVM_DEBUG(dbgs() << "\nStart examining consumer: " << *consumer.getOperation()); for (auto dependence : graph.getDependencesInto(