diff --git a/mlir/include/mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h b/mlir/include/mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h index 372f6c4e01a1..f27b929f2fc0 100644 --- a/mlir/include/mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h +++ b/mlir/include/mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h @@ -9,6 +9,7 @@ #ifndef MLIR_DIALECT_LINALG_ANALYSIS_DEPENDENCEANALYSIS_H_ #define MLIR_DIALECT_LINALG_ANALYSIS_DEPENDENCEANALYSIS_H_ +#include "mlir/Dialect/Linalg/IR/LinalgOps.h" #include "mlir/IR/Builders.h" #include "mlir/IR/OpDefinition.h" @@ -67,7 +68,7 @@ public: // Builds a linalg dependence graph for the ops of type LinalgOp under `f`. static LinalgDependenceGraph buildDependenceGraph(Aliases &aliases, FuncOp f); - LinalgDependenceGraph(Aliases &aliases, ArrayRef ops); + LinalgDependenceGraph(Aliases &aliases, ArrayRef ops); /// Returns the X such that op -> X is a dependence of type dt. dependence_range getDependencesFrom(Operation *src, DependenceType dt) const; @@ -168,7 +169,7 @@ private: ArrayRef types) const; Aliases &aliases; - SmallVector linalgOps; + SmallVector linalgOps; DenseMap linalgOpPositions; }; } // namespace linalg diff --git a/mlir/lib/Dialect/Linalg/Analysis/DependenceAnalysis.cpp b/mlir/lib/Dialect/Linalg/Analysis/DependenceAnalysis.cpp index 01e167d1f0aa..96da933888f2 100644 --- a/mlir/lib/Dialect/Linalg/Analysis/DependenceAnalysis.cpp +++ b/mlir/lib/Dialect/Linalg/Analysis/DependenceAnalysis.cpp @@ -86,21 +86,21 @@ StringRef LinalgDependenceGraph::getDependenceTypeStr(DependenceType depType) { LinalgDependenceGraph LinalgDependenceGraph::buildDependenceGraph(Aliases &aliases, FuncOp f) { - SmallVector linalgOps; + SmallVector linalgOps; f.walk([&](LinalgOp op) { linalgOps.push_back(op); }); return LinalgDependenceGraph(aliases, linalgOps); } LinalgDependenceGraph::LinalgDependenceGraph(Aliases &aliases, - ArrayRef ops) + ArrayRef ops) : aliases(aliases), linalgOps(ops.begin(), ops.end()) { for (auto en : llvm::enumerate(linalgOps)) { - assert(isa(en.value()) && "Expected value for LinalgOp"); - linalgOpPositions.insert(std::make_pair(en.value(), en.index())); + linalgOpPositions.insert( + std::make_pair(en.value().getOperation(), en.index())); } for (unsigned i = 0, e = ops.size(); i < e; ++i) { for (unsigned j = i + 1; j < e; ++j) { - addDependencesBetween(cast(ops[i]), cast(ops[j])); + addDependencesBetween(ops[i], ops[j]); } } } diff --git a/mlir/test/lib/Transforms/TestLinalgFusionTransforms.cpp b/mlir/test/lib/Transforms/TestLinalgFusionTransforms.cpp index e6e150b7bf47..eb9e3a533138 100644 --- a/mlir/test/lib/Transforms/TestLinalgFusionTransforms.cpp +++ b/mlir/test/lib/Transforms/TestLinalgFusionTransforms.cpp @@ -124,7 +124,7 @@ static LogicalResult fuseLinalgOpsGreedily(FuncOp f) { DenseSet eraseSet; // Save original Linalg ops, we only want to make a pass over those. - SmallVector linalgOps; + SmallVector linalgOps; f.walk([&](LinalgOp op) { // TODO: support multi-results. if (op.getOperation()->getNumResults() <= 1) @@ -133,8 +133,7 @@ static LogicalResult fuseLinalgOpsGreedily(FuncOp f) { // Tile and Fuse for tensors inputs (TODO: all tensor operands). bool changed = false; - for (auto *op : llvm::reverse(linalgOps)) { - LinalgOp linalgOp = cast(op); + for (LinalgOp linalgOp : llvm::reverse(linalgOps)) { for (auto en : llvm::enumerate(linalgOp.getShapedOperands())) { if (en.value().getType().isa()) { // TODO: LinalgDependenceGraph should be able to update itself. @@ -142,7 +141,7 @@ static LogicalResult fuseLinalgOpsGreedily(FuncOp f) { // removed. linalg::Aliases aliases; linalg::LinalgDependenceGraph graph(aliases, linalgOps); - if (auto info = fuseProducerOfBuffer(b, op, en.index(), graph)) { + if (auto info = fuseProducerOfBuffer(b, linalgOp, en.index(), graph)) { auto *originalOp = info->originalProducer.getOperation(); eraseSet.insert(originalOp); auto *originalOpInLinalgOpsVector = @@ -155,7 +154,7 @@ static LogicalResult fuseLinalgOpsGreedily(FuncOp f) { // Tile and Fuse tensor input (TODO: init_tensors too). if (en.index() >= linalgOp.getNumInputs()) continue; - if (auto info = fuseProducerOfTensor(b, op, en.index())) { + if (auto info = fuseProducerOfTensor(b, linalgOp, en.index())) { auto *originalOp = info->originalProducer.getOperation(); auto *originalOpInLinalgOpsVector = std::find(linalgOps.begin(), linalgOps.end(), originalOp);