forked from OSchip/llvm-project
[mlir] NFC - Add debug information for Linalg transformations.
Address post-commit review of https://reviews.llvm.org/D79518
This commit is contained in:
parent
8f8029b458
commit
91beb5176b
|
@ -521,7 +521,7 @@ struct LinalgCopyVTWForwardingPattern
|
|||
LogicalResult applyStagedPatterns(
|
||||
Operation *op, ArrayRef<OwningRewritePatternList> stage1Patterns,
|
||||
const OwningRewritePatternList &stage2Patterns,
|
||||
llvm::function_ref<LogicalResult(Operation *)> stage3Lambda = nullptr);
|
||||
function_ref<LogicalResult(Operation *)> stage3Lambda = nullptr);
|
||||
} // namespace linalg
|
||||
} // namespace mlir
|
||||
|
||||
|
|
|
@ -394,7 +394,7 @@ public:
|
|||
/// type `T`.
|
||||
template <typename T>
|
||||
OwningRewritePatternList(T &&t) {
|
||||
patterns.emplace_back(std::make_unique<T>(t));
|
||||
patterns.emplace_back(std::make_unique<T>(std::forward<T>(t)));
|
||||
}
|
||||
|
||||
PatternListT::iterator begin() { return patterns.begin(); }
|
||||
|
|
|
@ -37,6 +37,8 @@ using namespace mlir::linalg;
|
|||
|
||||
using llvm::dbgs;
|
||||
|
||||
#define DEBUG_TYPE "linalg-transforms"
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Transformations exposed as rewrite patterns.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -45,13 +47,13 @@ const StringLiteral mlir::linalg::LinalgTransforms::kLinalgTransformMarker =
|
|||
"__internal_linalg_transform__";
|
||||
|
||||
mlir::linalg::LinalgMarker::LinalgMarker(ArrayRef<StringRef> matchDisjunction,
|
||||
llvm::Optional<StringRef> replacement)
|
||||
Optional<StringRef> replacement)
|
||||
: matchDisjunction(matchDisjunction.begin(), matchDisjunction.end()),
|
||||
replacement(replacement) {}
|
||||
|
||||
mlir::linalg::LinalgMarker::LinalgMarker(ArrayRef<StringRef> matchDisjunction,
|
||||
StringRef replacement)
|
||||
: LinalgMarker(matchDisjunction, llvm::Optional<StringRef>{replacement}) {}
|
||||
: LinalgMarker(matchDisjunction, Optional<StringRef>{replacement}) {}
|
||||
|
||||
LogicalResult
|
||||
mlir::linalg::LinalgMarker::checkAndNotify(PatternRewriter &rewriter,
|
||||
|
@ -72,7 +74,7 @@ mlir::linalg::LinalgMarker::checkAndNotify(PatternRewriter &rewriter,
|
|||
// 3. Has no marker but was expecting a marker.
|
||||
return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
|
||||
diag << " does not have any marker from list: ";
|
||||
llvm::interleaveComma(matchDisjunction, diag);
|
||||
interleaveComma(matchDisjunction, diag);
|
||||
});
|
||||
}
|
||||
|
||||
|
@ -84,7 +86,7 @@ mlir::linalg::LinalgMarker::checkAndNotify(PatternRewriter &rewriter,
|
|||
// 5. Fail to match.
|
||||
return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
|
||||
diag << " does not have any marker from list: ";
|
||||
llvm::interleaveComma(matchDisjunction, diag);
|
||||
interleaveComma(matchDisjunction, diag);
|
||||
});
|
||||
}
|
||||
|
||||
|
@ -105,7 +107,7 @@ mlir::linalg::LinalgTilingOptions::setTileSizes(ArrayRef<int64_t> ts) {
|
|||
OpBuilder::InsertionGuard guard(b);
|
||||
b.setInsertionPointToStart(
|
||||
&op->getParentOfType<FuncOp>().getBody().front());
|
||||
return llvm::to_vector<4>(llvm::map_range(tileSizes, [&](int64_t s) {
|
||||
return llvm::to_vector<4>(map_range(tileSizes, [&](int64_t s) {
|
||||
Value v = b.create<ConstantIndexOp>(op->getLoc(), s);
|
||||
return v;
|
||||
}));
|
||||
|
@ -217,19 +219,33 @@ LogicalResult mlir::linalg::LinalgBaseVectorizationPattern::matchAndRewrite(
|
|||
LogicalResult mlir::linalg::applyStagedPatterns(
|
||||
Operation *op, ArrayRef<OwningRewritePatternList> stage1Patterns,
|
||||
const OwningRewritePatternList &stage2Patterns,
|
||||
llvm::function_ref<LogicalResult(Operation *)> stage3Lambda) {
|
||||
function_ref<LogicalResult(Operation *)> stage3Lambda) {
|
||||
unsigned iteration = 0;
|
||||
(void)iteration;
|
||||
StringRef dbgPref = "\n[" DEBUG_TYPE "]: ";
|
||||
(void)dbgPref;
|
||||
for (const auto &patterns : stage1Patterns) {
|
||||
if (!applyPatternsAndFoldGreedily(op, patterns)) {
|
||||
llvm::dbgs() << "Underlying first stage rewrite did not converge";
|
||||
dbgs() << "Underlying first stage rewrite did not converge";
|
||||
return failure();
|
||||
}
|
||||
LLVM_DEBUG(dbgs()
|
||||
<< dbgPref << "After 1st stage, iter: " << ++iteration << "\n"
|
||||
<< *op);
|
||||
if (!applyPatternsAndFoldGreedily(op, stage2Patterns)) {
|
||||
llvm::dbgs() << "Underlying second stage rewrite did not converge";
|
||||
LLVM_DEBUG(dbgs()
|
||||
<< dbgPref << "Underlying 2nd stage rewrite did not converge");
|
||||
return failure();
|
||||
}
|
||||
LLVM_DEBUG(dbgs()
|
||||
<< dbgPref << "After 2nd stage, iter : " << iteration << "\n"
|
||||
<< *op);
|
||||
if (stage3Lambda) {
|
||||
if (failed(stage3Lambda(op)))
|
||||
return failure();
|
||||
LLVM_DEBUG(dbgs()
|
||||
<< dbgPref << "After 3rd stage, iter : " << iteration << "\n"
|
||||
<< *op);
|
||||
}
|
||||
}
|
||||
return success();
|
||||
|
|
Loading…
Reference in New Issue