diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h index 2e0673795f30..2e6a85926079 100644 --- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h @@ -521,7 +521,7 @@ struct LinalgCopyVTWForwardingPattern LogicalResult applyStagedPatterns( Operation *op, ArrayRef stage1Patterns, const OwningRewritePatternList &stage2Patterns, - llvm::function_ref stage3Lambda = nullptr); + function_ref stage3Lambda = nullptr); } // namespace linalg } // namespace mlir diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h index 6b124e0ecdfa..8178f71ec43d 100644 --- a/mlir/include/mlir/IR/PatternMatch.h +++ b/mlir/include/mlir/IR/PatternMatch.h @@ -394,7 +394,7 @@ public: /// type `T`. template OwningRewritePatternList(T &&t) { - patterns.emplace_back(std::make_unique(t)); + patterns.emplace_back(std::make_unique(std::forward(t))); } PatternListT::iterator begin() { return patterns.begin(); } diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp index 527d162298bf..76e118e482f0 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp @@ -37,6 +37,8 @@ using namespace mlir::linalg; using llvm::dbgs; +#define DEBUG_TYPE "linalg-transforms" + //===----------------------------------------------------------------------===// // Transformations exposed as rewrite patterns. //===----------------------------------------------------------------------===// @@ -45,13 +47,13 @@ const StringLiteral mlir::linalg::LinalgTransforms::kLinalgTransformMarker = "__internal_linalg_transform__"; mlir::linalg::LinalgMarker::LinalgMarker(ArrayRef matchDisjunction, - llvm::Optional replacement) + Optional replacement) : matchDisjunction(matchDisjunction.begin(), matchDisjunction.end()), replacement(replacement) {} mlir::linalg::LinalgMarker::LinalgMarker(ArrayRef matchDisjunction, StringRef replacement) - : LinalgMarker(matchDisjunction, llvm::Optional{replacement}) {} + : LinalgMarker(matchDisjunction, Optional{replacement}) {} LogicalResult mlir::linalg::LinalgMarker::checkAndNotify(PatternRewriter &rewriter, @@ -72,7 +74,7 @@ mlir::linalg::LinalgMarker::checkAndNotify(PatternRewriter &rewriter, // 3. Has no marker but was expecting a marker. return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) { diag << " does not have any marker from list: "; - llvm::interleaveComma(matchDisjunction, diag); + interleaveComma(matchDisjunction, diag); }); } @@ -84,7 +86,7 @@ mlir::linalg::LinalgMarker::checkAndNotify(PatternRewriter &rewriter, // 5. Fail to match. return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) { diag << " does not have any marker from list: "; - llvm::interleaveComma(matchDisjunction, diag); + interleaveComma(matchDisjunction, diag); }); } @@ -105,7 +107,7 @@ mlir::linalg::LinalgTilingOptions::setTileSizes(ArrayRef ts) { OpBuilder::InsertionGuard guard(b); b.setInsertionPointToStart( &op->getParentOfType().getBody().front()); - return llvm::to_vector<4>(llvm::map_range(tileSizes, [&](int64_t s) { + return llvm::to_vector<4>(map_range(tileSizes, [&](int64_t s) { Value v = b.create(op->getLoc(), s); return v; })); @@ -217,19 +219,33 @@ LogicalResult mlir::linalg::LinalgBaseVectorizationPattern::matchAndRewrite( LogicalResult mlir::linalg::applyStagedPatterns( Operation *op, ArrayRef stage1Patterns, const OwningRewritePatternList &stage2Patterns, - llvm::function_ref stage3Lambda) { + function_ref stage3Lambda) { + unsigned iteration = 0; + (void)iteration; + StringRef dbgPref = "\n[" DEBUG_TYPE "]: "; + (void)dbgPref; for (const auto &patterns : stage1Patterns) { if (!applyPatternsAndFoldGreedily(op, patterns)) { - llvm::dbgs() << "Underlying first stage rewrite did not converge"; + dbgs() << "Underlying first stage rewrite did not converge"; return failure(); } + LLVM_DEBUG(dbgs() + << dbgPref << "After 1st stage, iter: " << ++iteration << "\n" + << *op); if (!applyPatternsAndFoldGreedily(op, stage2Patterns)) { - llvm::dbgs() << "Underlying second stage rewrite did not converge"; + LLVM_DEBUG(dbgs() + << dbgPref << "Underlying 2nd stage rewrite did not converge"); return failure(); } + LLVM_DEBUG(dbgs() + << dbgPref << "After 2nd stage, iter : " << iteration << "\n" + << *op); if (stage3Lambda) { if (failed(stage3Lambda(op))) return failure(); + LLVM_DEBUG(dbgs() + << dbgPref << "After 3rd stage, iter : " << iteration << "\n" + << *op); } } return success();