2019-05-15 06:03:48 +08:00
|
|
|
//===- NestedMatcher.cpp - NestedMatcher Impl ----------------------------===//
|
2019-01-26 22:59:23 +08:00
|
|
|
//
|
2020-01-26 11:58:30 +08:00
|
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
2019-12-24 01:35:36 +08:00
|
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
2019-01-26 22:59:23 +08:00
|
|
|
//
|
2019-12-24 01:35:36 +08:00
|
|
|
//===----------------------------------------------------------------------===//
|
2019-01-26 22:59:23 +08:00
|
|
|
|
|
|
|
#include "mlir/Analysis/NestedMatcher.h"
|
2020-03-21 05:18:47 +08:00
|
|
|
#include "mlir/Dialect/Affine/IR/AffineOps.h"
|
2020-02-22 03:54:49 +08:00
|
|
|
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
2019-01-26 22:59:23 +08:00
|
|
|
|
|
|
|
#include "llvm/ADT/ArrayRef.h"
|
2019-01-31 23:16:29 +08:00
|
|
|
#include "llvm/ADT/STLExtras.h"
|
2019-01-26 22:59:23 +08:00
|
|
|
#include "llvm/Support/Allocator.h"
|
|
|
|
#include "llvm/Support/raw_ostream.h"
|
|
|
|
|
|
|
|
using namespace mlir;
|
|
|
|
|
|
|
|
/// Returns a reference to this thread's NestedMatch allocator slot.
/// The slot starts out null; whatever installs a real allocator lives outside
/// this file (presumably a scoped context object declared in the header —
/// confirm). NestedMatch::build() dereferences the slot unconditionally.
llvm::BumpPtrAllocator *&NestedMatch::allocator() {
  // A function-local thread_local gives each thread its own independent slot.
  thread_local llvm::BumpPtrAllocator *perThreadArena = nullptr;
  return perThreadArena;
}
|
|
|
|
|
2019-03-27 23:55:17 +08:00
|
|
|
/// Creates a NestedMatch recording that `operation` matched, with a copy of
/// `nestedMatches` as its children. Both the match object and the children
/// array are carved out of the thread's NestedMatch allocator, so the returned
/// value's children view remains valid for as long as that allocator's memory
/// does.
NestedMatch NestedMatch::build(Operation *operation,
                               ArrayRef<NestedMatch> nestedMatches) {
  llvm::BumpPtrAllocator *arena = allocator();
  NestedMatch *match = arena->Allocate<NestedMatch>();
  const auto numChildren = nestedMatches.size();
  NestedMatch *childStorage = arena->Allocate<NestedMatch>(numChildren);
  // Copy-construct the children into the raw (uninitialized) storage.
  std::uninitialized_copy(nestedMatches.begin(), nestedMatches.end(),
                          childStorage);

  new (match) NestedMatch();
  match->matchedOperation = operation;
  match->matchedChildren = ArrayRef<NestedMatch>(childStorage, numChildren);
  return *match;
}
|
|
|
|
|
2019-01-31 23:16:29 +08:00
|
|
|
/// Returns a reference to this thread's NestedPattern allocator slot.
/// The slot starts out null; it must be pointed at a live allocator before
/// nested patterns are copied, because copyNestedToThis() dereferences it
/// unconditionally. The code that installs it is not in this file.
llvm::BumpPtrAllocator *&NestedPattern::allocator() {
  // Each thread sees its own slot thanks to the function-local thread_local.
  thread_local llvm::BumpPtrAllocator *perThreadArena = nullptr;
  return perThreadArena;
}
|
|
|
|
|
2021-03-12 19:19:47 +08:00
|
|
|
/// Makes this pattern's nested patterns a deep copy of `nested`. Each element
/// is copy-constructed into storage obtained from the thread's pattern
/// allocator, so `nested` need not outlive this pattern.
/// Previously referenced nested patterns are NOT destroyed here; a caller that
/// owns some must call freeNested() first.
void NestedPattern::copyNestedToThis(ArrayRef<NestedPattern> nested) {
  if (nested.empty()) {
    // Still reset the view so the result always reflects `nested`. Otherwise a
    // pattern that previously had (possibly already destroyed) nested
    // patterns would keep pointing at them.
    nestedPatterns = ArrayRef<NestedPattern>();
    return;
  }

  auto *newNested = allocator()->Allocate<NestedPattern>(nested.size());
  // Copy-construct into the raw storage; copying recursively deep-copies each
  // element's own nested patterns through the copy constructor.
  std::uninitialized_copy(nested.begin(), nested.end(), newNested);
  nestedPatterns = ArrayRef<NestedPattern>(newNested, nested.size());
}
|
|
|
|
|
|
|
|
void NestedPattern::freeNested() {
|
|
|
|
for (const auto &p : nestedPatterns)
|
|
|
|
p.~NestedPattern();
|
|
|
|
}
|
|
|
|
|
2019-02-04 02:03:46 +08:00
|
|
|
/// Builds a pattern that accepts operations satisfying `filter` and requires
/// every pattern in `nested` to match under them (see matchOne). The nested
/// patterns are deep-copied into allocator-owned storage, so the caller's
/// array need not outlive this object. Nothing is skipped by default.
NestedPattern::NestedPattern(ArrayRef<NestedPattern> nested,
                             FilterFunctionType filter)
    : nestedPatterns(ArrayRef<NestedPattern>()), filter(filter),
      skip(nullptr) {
  copyNestedToThis(nested);
}
|
|
|
|
|
|
|
|
/// Copy constructor: deep-copies `other`'s nested patterns (by delegating to
/// the primary constructor) and copies its filter and skip operation.
NestedPattern::NestedPattern(const NestedPattern &other)
    : NestedPattern(other.nestedPatterns, other.filter) {
  // The primary constructor resets `skip`; carry over other's value instead.
  skip = other.skip;
}
|
|
|
|
|
|
|
|
/// Copy-assigns `other` into this pattern. The filter and skip operation are
/// copied, the nested patterns are deep-copied into fresh allocator storage,
/// and this pattern's previous nested patterns are destroyed.
NestedPattern &NestedPattern::operator=(const NestedPattern &other) {
  // Self-assignment is a no-op; avoid a pointless deep copy.
  if (this == &other)
    return *this;

  // Deep-copy `other` *before* releasing our own state. `other` may live
  // inside our own nested-pattern array (e.g. assigning one of this pattern's
  // children to it), in which case freeNested() would destroy it and we would
  // then be copying from a destroyed object.
  NestedPattern copy(other);

  freeNested();
  // Adopt the copy's freshly allocated nested patterns and detach them from
  // `copy`, so that `copy`'s destructor (presumably calling freeNested(), as
  // declared in the header — confirm) cannot destroy what we now own. This
  // also correctly leaves `nestedPatterns` empty when `other` has none,
  // instead of pointing at the elements freeNested() just destroyed.
  nestedPatterns = copy.nestedPatterns;
  copy.nestedPatterns = ArrayRef<NestedPattern>();
  filter = copy.filter;
  skip = copy.skip;
  return *this;
}
|
|
|
|
|
2019-01-31 23:16:29 +08:00
|
|
|
unsigned NestedPattern::getDepth() const {
|
|
|
|
if (nestedPatterns.empty()) {
|
2019-01-26 22:59:23 +08:00
|
|
|
return 1;
|
|
|
|
}
|
|
|
|
unsigned depth = 0;
|
2019-01-31 23:16:29 +08:00
|
|
|
for (auto &c : nestedPatterns) {
|
2019-01-26 22:59:23 +08:00
|
|
|
depth = std::max(depth, c.getDepth());
|
|
|
|
}
|
|
|
|
return depth + 1;
|
|
|
|
}
|
|
|
|
|
2019-03-27 23:55:17 +08:00
|
|
|
/// Matches a single operation in the following way:
|
|
|
|
/// 1. checks the kind of operation against the matcher, if different then
|
2019-01-26 22:59:23 +08:00
|
|
|
/// there is no match;
|
2019-03-27 23:55:17 +08:00
|
|
|
/// 2. calls the customizable filter function to refine the single operation
|
2019-01-26 22:59:23 +08:00
|
|
|
/// match with extra semantic constraints;
|
2019-10-20 15:11:03 +08:00
|
|
|
/// 3. if all is good, recursively matches the nested patterns;
|
2019-03-27 23:55:17 +08:00
|
|
|
/// 4. if all nested match then the single operation matches too and is
|
2019-01-26 22:59:23 +08:00
|
|
|
/// appended to the list of matches;
|
2020-07-07 16:35:23 +08:00
|
|
|
/// 5. TODO: Optionally applies actions (lambda), in which case we will want
|
|
|
|
/// to traverse in post-order DFS to avoid invalidating iterators.
|
2019-03-27 23:55:17 +08:00
|
|
|
void NestedPattern::matchOne(Operation *op,
|
2019-01-31 23:16:29 +08:00
|
|
|
SmallVectorImpl<NestedMatch> *matches) {
|
2019-03-27 23:55:17 +08:00
|
|
|
if (skip == op) {
|
2019-01-26 22:59:23 +08:00
|
|
|
return;
|
|
|
|
}
|
|
|
|
// Local custom filter function
|
2019-03-27 23:55:17 +08:00
|
|
|
if (!filter(*op)) {
|
2019-01-26 22:59:23 +08:00
|
|
|
return;
|
|
|
|
}
|
|
|
|
|
2019-01-31 23:16:29 +08:00
|
|
|
if (nestedPatterns.empty()) {
|
|
|
|
SmallVector<NestedMatch, 8> nestedMatches;
|
2019-03-27 23:55:17 +08:00
|
|
|
matches->push_back(NestedMatch::build(op, nestedMatches));
|
2019-01-31 23:16:29 +08:00
|
|
|
return;
|
|
|
|
}
|
|
|
|
// Take a copy of each nested pattern so we can match it.
|
|
|
|
for (auto nestedPattern : nestedPatterns) {
|
|
|
|
SmallVector<NestedMatch, 8> nestedMatches;
|
2019-01-26 22:59:23 +08:00
|
|
|
// Skip elem in the walk immediately following. Without this we would
|
2019-04-05 02:13:02 +08:00
|
|
|
// essentially need to reimplement walk here.
|
2019-03-27 23:55:17 +08:00
|
|
|
nestedPattern.skip = op;
|
|
|
|
nestedPattern.match(op, &nestedMatches);
|
2019-01-31 23:16:29 +08:00
|
|
|
// If we could not match even one of the specified nestedPattern, early exit
|
|
|
|
// as this whole branch is not a match.
|
|
|
|
if (nestedMatches.empty()) {
|
2019-01-26 22:59:23 +08:00
|
|
|
return;
|
|
|
|
}
|
2019-03-27 23:55:17 +08:00
|
|
|
matches->push_back(NestedMatch::build(op, nestedMatches));
|
2019-01-26 22:59:23 +08:00
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2019-05-12 09:59:54 +08:00
|
|
|
/// Filter predicate: true iff `op` is an AffineForOp.
static bool isAffineForOp(Operation &op) {
  return isa<AffineForOp>(op);
}
|
2019-02-02 08:42:18 +08:00
|
|
|
|
2019-05-12 09:59:54 +08:00
|
|
|
/// Filter predicate: true iff `op` is an AffineIfOp.
static bool isAffineIfOp(Operation &op) {
  return isa<AffineIfOp>(op);
}
|
2019-01-29 13:23:53 +08:00
|
|
|
|
2019-01-26 22:59:23 +08:00
|
|
|
namespace mlir {
|
|
|
|
namespace matcher {
|
|
|
|
|
|
|
|
NestedPattern Op(FilterFunctionType filter) {
|
2019-02-04 02:03:46 +08:00
|
|
|
return NestedPattern({}, filter);
|
2019-01-26 22:59:23 +08:00
|
|
|
}
|
|
|
|
|
|
|
|
NestedPattern If(NestedPattern child) {
|
2019-02-04 02:03:46 +08:00
|
|
|
return NestedPattern(child, isAffineIfOp);
|
2019-01-26 22:59:23 +08:00
|
|
|
}
|
|
|
|
NestedPattern If(FilterFunctionType filter, NestedPattern child) {
|
2019-03-27 23:55:17 +08:00
|
|
|
return NestedPattern(child, [filter](Operation &op) {
|
|
|
|
return isAffineIfOp(op) && filter(op);
|
2019-02-04 02:03:46 +08:00
|
|
|
});
|
2019-01-26 22:59:23 +08:00
|
|
|
}
|
|
|
|
NestedPattern If(ArrayRef<NestedPattern> nested) {
|
2019-02-04 02:03:46 +08:00
|
|
|
return NestedPattern(nested, isAffineIfOp);
|
2019-01-26 22:59:23 +08:00
|
|
|
}
|
|
|
|
NestedPattern If(FilterFunctionType filter, ArrayRef<NestedPattern> nested) {
|
2019-03-27 23:55:17 +08:00
|
|
|
return NestedPattern(nested, [filter](Operation &op) {
|
|
|
|
return isAffineIfOp(op) && filter(op);
|
2019-02-04 02:03:46 +08:00
|
|
|
});
|
2019-01-26 22:59:23 +08:00
|
|
|
}
|
|
|
|
|
|
|
|
NestedPattern For(NestedPattern child) {
|
2019-02-04 02:03:46 +08:00
|
|
|
return NestedPattern(child, isAffineForOp);
|
2019-01-26 22:59:23 +08:00
|
|
|
}
|
|
|
|
NestedPattern For(FilterFunctionType filter, NestedPattern child) {
|
2019-03-27 23:55:17 +08:00
|
|
|
return NestedPattern(
|
|
|
|
child, [=](Operation &op) { return isAffineForOp(op) && filter(op); });
|
2019-01-26 22:59:23 +08:00
|
|
|
}
|
|
|
|
NestedPattern For(ArrayRef<NestedPattern> nested) {
|
2019-02-04 02:03:46 +08:00
|
|
|
return NestedPattern(nested, isAffineForOp);
|
2019-01-26 22:59:23 +08:00
|
|
|
}
|
|
|
|
NestedPattern For(FilterFunctionType filter, ArrayRef<NestedPattern> nested) {
|
2019-03-27 23:55:17 +08:00
|
|
|
return NestedPattern(
|
|
|
|
nested, [=](Operation &op) { return isAffineForOp(op) && filter(op); });
|
2019-01-26 22:59:23 +08:00
|
|
|
}
|
|
|
|
|
2019-03-27 23:55:17 +08:00
|
|
|
bool isLoadOrStore(Operation &op) {
|
2020-06-25 02:49:30 +08:00
|
|
|
return isa<AffineLoadOp, AffineStoreOp>(op);
|
2019-05-11 04:54:11 +08:00
|
|
|
}
|
2019-01-26 22:59:23 +08:00
|
|
|
|
2021-12-08 02:27:58 +08:00
|
|
|
} // namespace matcher
|
|
|
|
} // namespace mlir
|