[mlir] NFC - Add debug information for Linalg transformations.

Address post-commit review of https://reviews.llvm.org/D79518
This commit is contained in:
Nicolas Vasilache 2020-05-29 18:07:39 -04:00
parent 8f8029b458
commit 91beb5176b
3 changed files with 26 additions and 10 deletions

View File

@ -521,7 +521,7 @@ struct LinalgCopyVTWForwardingPattern
LogicalResult applyStagedPatterns(
Operation *op, ArrayRef<OwningRewritePatternList> stage1Patterns,
const OwningRewritePatternList &stage2Patterns,
llvm::function_ref<LogicalResult(Operation *)> stage3Lambda = nullptr);
function_ref<LogicalResult(Operation *)> stage3Lambda = nullptr);
} // namespace linalg
} // namespace mlir

View File

@ -394,7 +394,7 @@ public:
/// type `T`.
template <typename T>
OwningRewritePatternList(T &&t) {
patterns.emplace_back(std::make_unique<T>(t));
patterns.emplace_back(std::make_unique<T>(std::forward<T>(t)));
}
PatternListT::iterator begin() { return patterns.begin(); }

View File

@ -37,6 +37,8 @@ using namespace mlir::linalg;
using llvm::dbgs;
#define DEBUG_TYPE "linalg-transforms"
//===----------------------------------------------------------------------===//
// Transformations exposed as rewrite patterns.
//===----------------------------------------------------------------------===//
@ -45,13 +47,13 @@ const StringLiteral mlir::linalg::LinalgTransforms::kLinalgTransformMarker =
"__internal_linalg_transform__";
mlir::linalg::LinalgMarker::LinalgMarker(ArrayRef<StringRef> matchDisjunction,
llvm::Optional<StringRef> replacement)
Optional<StringRef> replacement)
: matchDisjunction(matchDisjunction.begin(), matchDisjunction.end()),
replacement(replacement) {}
mlir::linalg::LinalgMarker::LinalgMarker(ArrayRef<StringRef> matchDisjunction,
StringRef replacement)
: LinalgMarker(matchDisjunction, llvm::Optional<StringRef>{replacement}) {}
: LinalgMarker(matchDisjunction, Optional<StringRef>{replacement}) {}
LogicalResult
mlir::linalg::LinalgMarker::checkAndNotify(PatternRewriter &rewriter,
@ -72,7 +74,7 @@ mlir::linalg::LinalgMarker::checkAndNotify(PatternRewriter &rewriter,
// 3. Has no marker but was expecting a marker.
return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
diag << " does not have any marker from list: ";
llvm::interleaveComma(matchDisjunction, diag);
interleaveComma(matchDisjunction, diag);
});
}
@ -84,7 +86,7 @@ mlir::linalg::LinalgMarker::checkAndNotify(PatternRewriter &rewriter,
// 5. Fail to match.
return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
diag << " does not have any marker from list: ";
llvm::interleaveComma(matchDisjunction, diag);
interleaveComma(matchDisjunction, diag);
});
}
@ -105,7 +107,7 @@ mlir::linalg::LinalgTilingOptions::setTileSizes(ArrayRef<int64_t> ts) {
OpBuilder::InsertionGuard guard(b);
b.setInsertionPointToStart(
&op->getParentOfType<FuncOp>().getBody().front());
return llvm::to_vector<4>(llvm::map_range(tileSizes, [&](int64_t s) {
return llvm::to_vector<4>(map_range(tileSizes, [&](int64_t s) {
Value v = b.create<ConstantIndexOp>(op->getLoc(), s);
return v;
}));
@ -217,19 +219,33 @@ LogicalResult mlir::linalg::LinalgBaseVectorizationPattern::matchAndRewrite(
LogicalResult mlir::linalg::applyStagedPatterns(
Operation *op, ArrayRef<OwningRewritePatternList> stage1Patterns,
const OwningRewritePatternList &stage2Patterns,
llvm::function_ref<LogicalResult(Operation *)> stage3Lambda) {
function_ref<LogicalResult(Operation *)> stage3Lambda) {
unsigned iteration = 0;
(void)iteration;
StringRef dbgPref = "\n[" DEBUG_TYPE "]: ";
(void)dbgPref;
for (const auto &patterns : stage1Patterns) {
if (!applyPatternsAndFoldGreedily(op, patterns)) {
llvm::dbgs() << "Underlying first stage rewrite did not converge";
dbgs() << "Underlying first stage rewrite did not converge";
return failure();
}
LLVM_DEBUG(dbgs()
<< dbgPref << "After 1st stage, iter: " << ++iteration << "\n"
<< *op);
if (!applyPatternsAndFoldGreedily(op, stage2Patterns)) {
llvm::dbgs() << "Underlying second stage rewrite did not converge";
LLVM_DEBUG(dbgs()
<< dbgPref << "Underlying 2nd stage rewrite did not converge");
return failure();
}
LLVM_DEBUG(dbgs()
<< dbgPref << "After 2nd stage, iter : " << iteration << "\n"
<< *op);
if (stage3Lambda) {
if (failed(stage3Lambda(op)))
return failure();
LLVM_DEBUG(dbgs()
<< dbgPref << "After 3rd stage, iter : " << iteration << "\n"
<< *op);
}
}
return success();