llvm-project/mlir/lib/Analysis/MLFunctionMatcher.cpp
[MLIR] Basic infrastructure for vectorization test

This CL implements a very simple loop vectorization **test** and the basic
infrastructure to support it. The test consists of:
1. matching the loops in the MLFunction and all the Load/Store operations
   nested under each loop;
2. testing whether all the Load/Store operations are contiguous along the
   innermost memory dimension for that particular loop. If any reference is
   non-contiguous (i.e. the ForStmt SSAValue appears in the access
   expression), then the loop is not vectorizable.

The simple test above can gradually be extended with more interesting
behaviors, e.g. to account for the fact that a layout permutation may exist
that enables contiguity. All these will come in due time, but it is worth
noting that the test already supports detection of outer-vectorizable loops.

In implementing this test, I also added a recursive MLFunctionMatcher and some
sugar that can capture patterns such as
`auto gemmLike = Doall(Doall(Red(LoadStore())))` and allows iterating on the
matched IR structures. For now it uses in-order traversal, but post-order DFS
will be useful in the future once IR rewrites start occurring.

One may note that the memory management design decision follows a different
pattern from MLIR. After evaluating different designs and how quickly they
increase cognitive overhead, I opted for the simplest solution in my view: a
class-wide (thread-safe) RAII context. This way, a pass that needs
MLFunctionMatcher can just have its own locally scoped BumpPtrAllocator, and
everything is cleaned up when the pass is destroyed. If passes are expected to
have a longer lifetime, the contexts can easily be scoped inside the
runOnMLFunction call and storage lifetime reduced. Lastly, whatever the scope
of threading (module, function, pass), this is expected to be future-proof
w.r.t. concurrency (but this is a detail atm).

PiperOrigin-RevId: 217622889
2018-10-18 09:01:44 +08:00
//===- MLFunctionMatcher.cpp - MLFunctionMatcher Impl ----------*- C++ -*-===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "mlir/Analysis/MLFunctionMatcher.h"
#include "mlir/StandardOps/StandardOps.h"
#include "llvm/Support/Allocator.h"
namespace mlir {
/// Underlying storage for MLFunctionMatches.
struct MLFunctionMatchesStorage {
MLFunctionMatchesStorage(MLFunctionMatches::EntryType e) : matches({e}) {}
SmallVector<MLFunctionMatches::EntryType, 8> matches;
};
/// Underlying storage for MLFunctionMatcher.
struct MLFunctionMatcherStorage {
MLFunctionMatcherStorage(Statement::Kind k,
MutableArrayRef<MLFunctionMatcher> c,
FilterFunctionType filter)
: kind(k), childrenMLFunctionMatchers(c.begin(), c.end()),
filter(filter) {}
Statement::Kind kind;
SmallVector<MLFunctionMatcher, 4> childrenMLFunctionMatchers;
FilterFunctionType filter;
};
} // end namespace mlir
using namespace mlir;
llvm::BumpPtrAllocator *&MLFunctionMatches::allocator() {
static thread_local llvm::BumpPtrAllocator *allocator = nullptr;
return allocator;
}
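// The thread-local allocator accessor above is the "class-wide (thread-safe)
// RAII context" described in the commit message. The following is a minimal,
// stdlib-only sketch of that pattern; `BumpAllocator`,
// `ScopedMatcherContext`, and `demoScopedContext` are hypothetical stand-ins,
// not MLIR or LLVM types:

```cpp
#include <cstddef>
#include <new>
#include <vector>

// Hypothetical stand-in for llvm::BumpPtrAllocator: frees everything at once
// when it goes out of scope.
struct BumpAllocator {
  std::vector<void *> blocks;
  void *allocate(std::size_t n) {
    void *p = ::operator new(n);
    blocks.push_back(p);
    return p;
  }
  ~BumpAllocator() {
    for (void *p : blocks)
      ::operator delete(p);
  }
};

// Same accessor shape as MLFunctionMatches::allocator(): a thread-local
// pointer that an RAII scope installs on entry and clears on exit.
inline BumpAllocator *&allocator() {
  static thread_local BumpAllocator *alloc = nullptr;
  return alloc;
}

// A pass-local context, scoped e.g. to a runOnMLFunction-style entry point.
struct ScopedMatcherContext {
  BumpAllocator local;
  ScopedMatcherContext() { allocator() = &local; }
  ~ScopedMatcherContext() { allocator() = nullptr; }
};

// Exercises the pattern: allocate inside the scope, observe cleanup outside.
inline bool demoScopedContext() {
  if (allocator() != nullptr)
    return false;
  {
    ScopedMatcherContext ctx;
    int *x = static_cast<int *>(allocator()->allocate(sizeof(int)));
    *x = 42;
    if (*x != 42)
      return false;
  }
  return allocator() == nullptr; // storage reclaimed with the scope
}
```

// A pass would construct one ScopedMatcherContext at the top of its run
// method; all matcher storage allocated below is reclaimed when it returns.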
void MLFunctionMatches::append(Statement *stmt, MLFunctionMatches children) {
if (!storage) {
storage = allocator()->Allocate<MLFunctionMatchesStorage>();
new (storage) MLFunctionMatchesStorage(std::make_pair(stmt, children));
} else {
storage->matches.push_back(std::make_pair(stmt, children));
}
}
MLFunctionMatches::iterator MLFunctionMatches::begin() {
return storage->matches.begin();
}
MLFunctionMatches::iterator MLFunctionMatches::end() {
return storage->matches.end();
}
/// Return the combination of multiple MLFunctionMatches as a new object.
static MLFunctionMatches combine(ArrayRef<MLFunctionMatches> matches) {
MLFunctionMatches res;
for (auto s : matches) {
for (auto ss : s) {
res.append(ss.first, ss.second);
}
}
return res;
}
/// Calls walk on `function`.
MLFunctionMatches &MLFunctionMatcher::match(MLFunction *function) {
assert(!matches && "MLFunctionMatcher already matched!");
this->walk(function);
return matches;
}
/// Calls walk on `statement`.
MLFunctionMatches &MLFunctionMatcher::match(Statement *statement) {
assert(!matches && "MLFunctionMatcher already matched!");
this->walk(statement);
return matches;
}
/// matchOrSkipOne is needed so that we can implement match without switching
/// on the type of the Statement.
/// The idea is that an MLFunctionMatcher first checks whether it matches
/// locally and then recursively applies its children matchers to its
/// elem->children. Since we want to rely on the StmtWalker impl rather than
/// duplicate its logic, we allow an off-by-one traversal to account for the
/// fact that we write:
///
///   void match(Statement *elem) {
///     for (auto &c : getChildrenMLFunctionMatchers()) {
///       MLFunctionMatcher childMLFunctionMatcher(...);
///       childMLFunctionMatcher.walk(elem); // <~~~ Needs off-by-one traversal
///     }
///   }
///
void MLFunctionMatcher::matchOrSkipOne(Statement *elem) {
if (skipOne) {
skipOne = false;
return;
}
matchOne(elem);
}
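// The skipOne mechanism can be illustrated with a stdlib-only sketch: a
// matcher forked for a child pattern is walked from the node the parent just
// matched, so it must skip that first node and only consider descendants.
// `TinyStmt`, `TinyMatcher`, and `demoSkipOne` are hypothetical stand-ins,
// not the MLIR types:

```cpp
#include <vector>

// Toy statement tree: kind 0 = for loop, kind 1 = operation.
struct TinyStmt {
  int kind;
  std::vector<TinyStmt> children;
};

// Pre-order walk over a node and all of its descendants, like StmtWalker.
template <typename F> void walkTree(const TinyStmt &s, F visit) {
  visit(s);
  for (const TinyStmt &c : s.children)
    walkTree(c, visit);
}

struct TinyMatcher {
  int kind;
  bool skipOne;
  int numMatches;
  // Mirrors matchOrSkipOne above: consume the skipOne flag once, then match.
  void matchOrSkipOne(const TinyStmt &s) {
    if (skipOne) {
      skipOne = false;
      return;
    }
    if (s.kind == kind)
      ++numMatches;
  }
  void walkFrom(const TinyStmt &root) {
    walkTree(root, [this](const TinyStmt &s) { matchOrSkipOne(s); });
  }
};

// With skipOne set, a child matcher walked from a matched `for` only sees the
// loop nested below it, not the loop itself.
inline bool demoSkipOne() {
  TinyStmt ir{0, {TinyStmt{0, {TinyStmt{1, {}}}}}}; // for { for { op } }
  TinyMatcher plain{0, false, 0};
  plain.walkFrom(ir); // visits and matches both loops
  TinyMatcher forked{0, true, 0};
  forked.walkFrom(ir); // skips the root, matches only the nested loop
  return plain.numMatches == 2 && forked.numMatches == 1;
}
```

// This is exactly the off-by-one the doc comment above describes: without
// skipOne, a child `For` matcher applied at a matched `for` would match that
// same statement again.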
/// Matches a single statement in the following way:
/// 1. checks the kind of statement against the matcher, if different then
/// there is no match;
/// 2. calls the customizable filter function to refine the single statement
/// match with extra semantic constraints;
/// 3. if all is good, recursively matches the children patterns;
/// 4. if all children match then the single statement matches too and is
/// appended to the list of matches;
/// 5. TODO(ntv) Optionally applies actions (lambda), in which case we will
/// want to traverse in post-order DFS to avoid invalidating iterators.
void MLFunctionMatcher::matchOne(Statement *elem) {
// Structural filter
if (elem->getKind() != getKind()) {
return;
}
// Local custom filter function
if (!getFilterFunction()(elem)) {
return;
}
SmallVector<MLFunctionMatches, 8> childrenMLFunctionMatches;
for (auto &c : getChildrenMLFunctionMatchers()) {
/// We create a new childMLFunctionMatcher here because a matcher holds its
/// results, so we concretely need multiple copies of a given matcher, one
/// for each matching result.
MLFunctionMatcher childMLFunctionMatcher = forkMLFunctionMatcher(c);
childMLFunctionMatcher.walk(elem);
if (!childMLFunctionMatcher.matches) {
return;
}
childrenMLFunctionMatches.push_back(childMLFunctionMatcher.matches);
}
matches.append(elem, combine(childrenMLFunctionMatches));
}
llvm::BumpPtrAllocator *&MLFunctionMatcher::allocator() {
static thread_local llvm::BumpPtrAllocator *allocator = nullptr;
return allocator;
}
MLFunctionMatcher::MLFunctionMatcher(Statement::Kind k, MLFunctionMatcher child,
FilterFunctionType filter)
: storage(allocator()->Allocate<MLFunctionMatcherStorage>()),
skipOne(false) {
// Initialize with placement new.
new (storage) MLFunctionMatcherStorage(k, {child}, filter);
}
MLFunctionMatcher::MLFunctionMatcher(
Statement::Kind k, MutableArrayRef<MLFunctionMatcher> children,
FilterFunctionType filter)
: storage(allocator()->Allocate<MLFunctionMatcherStorage>()),
skipOne(false) {
// Initialize with placement new.
new (storage) MLFunctionMatcherStorage(k, children, filter);
}
MLFunctionMatcher
MLFunctionMatcher::forkMLFunctionMatcher(MLFunctionMatcher tmpl) {
MLFunctionMatcher res(tmpl.getKind(), tmpl.getChildrenMLFunctionMatchers(),
tmpl.getFilterFunction());
res.skipOne = true;
return res;
}
Statement::Kind MLFunctionMatcher::getKind() { return storage->kind; }
MutableArrayRef<MLFunctionMatcher>
MLFunctionMatcher::getChildrenMLFunctionMatchers() {
return storage->childrenMLFunctionMatchers;
}
FilterFunctionType MLFunctionMatcher::getFilterFunction() {
return storage->filter;
}
namespace mlir {
namespace matcher {
MLFunctionMatcher Op(FilterFunctionType filter) {
return MLFunctionMatcher(Statement::Kind::Operation, {}, filter);
}
MLFunctionMatcher If(MLFunctionMatcher child) {
return MLFunctionMatcher(Statement::Kind::If, child, defaultFilterFunction);
}
MLFunctionMatcher If(FilterFunctionType filter, MLFunctionMatcher child) {
return MLFunctionMatcher(Statement::Kind::If, child, filter);
}
MLFunctionMatcher If(MutableArrayRef<MLFunctionMatcher> children) {
return MLFunctionMatcher(Statement::Kind::If, children,
defaultFilterFunction);
}
MLFunctionMatcher If(FilterFunctionType filter,
MutableArrayRef<MLFunctionMatcher> children) {
return MLFunctionMatcher(Statement::Kind::If, children, filter);
}
MLFunctionMatcher For(MLFunctionMatcher child) {
return MLFunctionMatcher(Statement::Kind::For, child, defaultFilterFunction);
}
MLFunctionMatcher For(FilterFunctionType filter, MLFunctionMatcher child) {
return MLFunctionMatcher(Statement::Kind::For, child, filter);
}
MLFunctionMatcher For(MutableArrayRef<MLFunctionMatcher> children) {
return MLFunctionMatcher(Statement::Kind::For, children,
defaultFilterFunction);
}
MLFunctionMatcher For(FilterFunctionType filter,
MutableArrayRef<MLFunctionMatcher> children) {
return MLFunctionMatcher(Statement::Kind::For, children, filter);
}
// TODO(ntv): parallel annotation on loops.
FilterFunctionType isParallelLoop = [](Statement *stmt) {
auto *loop = cast<ForStmt>(stmt);
(void)loop;
return true; // Placeholder until loop->isParallel() exists.
};
MLFunctionMatcher Doall(MLFunctionMatcher child) {
return MLFunctionMatcher(Statement::Kind::For, child, isParallelLoop);
}
MLFunctionMatcher Doall(MutableArrayRef<MLFunctionMatcher> children) {
return MLFunctionMatcher(Statement::Kind::For, children, isParallelLoop);
}
// TODO(ntv): reduction annotation on loops.
FilterFunctionType isReductionLoop = [](Statement *stmt) {
auto *loop = cast<ForStmt>(stmt);
(void)loop;
return true; // Placeholder until loop->isReduction() exists.
};
MLFunctionMatcher Red(MLFunctionMatcher child) {
return MLFunctionMatcher(Statement::Kind::For, child, isReductionLoop);
}
MLFunctionMatcher Red(MutableArrayRef<MLFunctionMatcher> children) {
return MLFunctionMatcher(Statement::Kind::For, children, isReductionLoop);
}
FilterFunctionType isLoadOrStore = [](Statement *stmt) {
auto *opStmt = dyn_cast<OperationStmt>(stmt);
return opStmt && (opStmt->isa<LoadOp>() || opStmt->isa<StoreOp>());
return opStmt && (opStmt->isa<LoadOp>() || opStmt->isa<StoreOp>());
};
MLFunctionMatcher LoadStores() {
return MLFunctionMatcher(Statement::Kind::Operation, {}, isLoadOrStore);
}
} // end namespace matcher
} // end namespace mlir
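// As an illustration of the composed-pattern idea behind sugar like
// `Doall(Doall(Red(LoadStores())))`, here is a minimal, stdlib-only sketch of
// recursive structural matching. All names (`Stmt2`, `Pattern`, `ForPat`,
// `OpPat`, `demoComposition`) are hypothetical stand-ins, not MLIR types:

```cpp
#include <vector>

// Toy statement and pattern trees: kind 0 = for loop, kind 1 = operation.
struct Stmt2 {
  int kind;
  std::vector<Stmt2> children;
};
struct Pattern {
  int kind;
  std::vector<Pattern> children;
};

bool matchesBelow(const Stmt2 &s, const Pattern &p);

// `s` matches `p` if the kinds agree and every child pattern matches
// somewhere strictly below `s` -- the same local-check-then-recurse shape as
// MLFunctionMatcher::matchOne.
inline bool matchesAt(const Stmt2 &s, const Pattern &p) {
  if (s.kind != p.kind)
    return false;
  for (const Pattern &c : p.children)
    if (!matchesBelow(s, c))
      return false;
  return true;
}

inline bool matchesBelow(const Stmt2 &s, const Pattern &p) {
  for (const Stmt2 &c : s.children)
    if (matchesAt(c, p) || matchesBelow(c, p))
      return true;
  return false;
}

// Sugar in the spirit of matcher::For / matcher::Op above.
inline Pattern ForPat(Pattern child) { return Pattern{0, {child}}; }
inline Pattern OpPat() { return Pattern{1, {}}; }

// A doubly nested loop around an op matches For(For(Op)) but not a triply
// nested pattern.
inline bool demoComposition() {
  Stmt2 ir{0, {Stmt2{0, {Stmt2{1, {}}}}}}; // for { for { op } }
  return matchesAt(ir, ForPat(ForPat(OpPat()))) &&
         !matchesAt(ir, ForPat(ForPat(ForPat(OpPat()))));
}
```

// Unlike this boolean sketch, the real matcher also records which statements
// matched (via MLFunctionMatches), which is why it forks per-result copies.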