forked from OSchip/llvm-project
Cleanup resource management and rename recursive matchers
This CL follows up on a memory leak issue related to SmallVector growth that escapes the BumpPtrAllocator. The fix is to properly use ArrayRef and placement new to define away the issue. The following renaming is also applied: 1. MLFunctionMatcher -> NestedPattern 2. MLFunctionMatches -> NestedMatch As a consequence all allocations are now guaranteed to live on the BumpPtrAllocator. PiperOrigin-RevId: 231047766
This commit is contained in:
parent
75c21e1de0
commit
81c7f2e2f3
|
@ -1,4 +1,4 @@
|
|||
//===- MLFunctionMacher.h - Recursive matcher for MLFunction ----*- C++ -*-===//
|
||||
//===- NestedMacher.h - Nested matcher for MLFunction -----------*- C++ -*-===//
|
||||
//
|
||||
// Copyright 2019 The MLIR Authors.
|
||||
//
|
||||
|
@ -24,22 +24,22 @@
|
|||
|
||||
namespace mlir {
|
||||
|
||||
struct MLFunctionMatcherStorage;
|
||||
struct MLFunctionMatchesStorage;
|
||||
struct NestedPatternStorage;
|
||||
struct NestedMatchStorage;
|
||||
class Instruction;
|
||||
|
||||
/// An MLFunctionMatcher is a recursive matcher that captures nested patterns in
|
||||
/// an ML Function. It is used in conjunction with a scoped
|
||||
/// MLFunctionMatcherContext that handles the memory allocations efficiently.
|
||||
/// An NestedPattern captures nested patterns. It is used in conjunction with
|
||||
/// a scoped NestedPatternContext which is an llvm::BumPtrAllocator that
|
||||
/// handles memory allocations efficiently and avoids ownership issues.
|
||||
///
|
||||
/// In order to use MLFunctionMatchers creates a scoped context and uses
|
||||
/// matchers. When the context goes out of scope, everything is freed.
|
||||
/// In order to use NestedPatterns, first create a scoped context. When the
|
||||
/// context goes out of scope, everything is freed.
|
||||
/// This design simplifies the API by avoiding references to the context and
|
||||
/// makes it clear that references to matchers must not escape.
|
||||
///
|
||||
/// Example:
|
||||
/// {
|
||||
/// MLFunctionMatcherContext context;
|
||||
/// NestedPatternContext context;
|
||||
/// auto gemmLike = Doall(Doall(Red(LoadStores())));
|
||||
/// auto matches = gemmLike.match(f);
|
||||
/// // do work on matches
|
||||
|
@ -51,15 +51,17 @@ class Instruction;
|
|||
///
|
||||
/// Implemented as a POD value-type with underlying storage pointer.
|
||||
/// The underlying storage lives in a scoped bumper allocator whose lifetime
|
||||
/// is managed by an RAII MLFunctionMatcherContext.
|
||||
/// This should be used by value everywhere.
|
||||
struct MLFunctionMatches {
|
||||
using EntryType = std::pair<Instruction *, MLFunctionMatches>;
|
||||
/// is managed by an RAII NestedPatternContext.
|
||||
/// This is used by value everywhere.
|
||||
struct NestedMatch {
|
||||
using EntryType = std::pair<Instruction *, NestedMatch>;
|
||||
using iterator = EntryType *;
|
||||
|
||||
MLFunctionMatches() : storage(nullptr) {}
|
||||
static NestedMatch build(ArrayRef<NestedMatch::EntryType> elements = {});
|
||||
NestedMatch(const NestedMatch &) = default;
|
||||
NestedMatch &operator=(const NestedMatch &) = default;
|
||||
|
||||
explicit operator bool() { return storage; }
|
||||
explicit operator bool() { return !empty(); }
|
||||
|
||||
iterator begin();
|
||||
iterator end();
|
||||
|
@ -68,20 +70,25 @@ struct MLFunctionMatches {
|
|||
unsigned size() { return end() - begin(); }
|
||||
unsigned empty() { return size() == 0; }
|
||||
|
||||
/// Appends the pair <inst, children> to the current matches.
|
||||
void append(Instruction *inst, MLFunctionMatches children);
|
||||
|
||||
private:
|
||||
friend class MLFunctionMatcher;
|
||||
friend class MLFunctionMatcherContext;
|
||||
friend class NestedPattern;
|
||||
friend class NestedPatternContext;
|
||||
friend class NestedMatchStorage;
|
||||
|
||||
/// Underlying global bump allocator managed by an MLFunctionMatcherContext.
|
||||
/// Underlying global bump allocator managed by a NestedPatternContext.
|
||||
static llvm::BumpPtrAllocator *&allocator();
|
||||
|
||||
MLFunctionMatchesStorage *storage;
|
||||
NestedMatch(NestedMatchStorage *storage) : storage(storage){};
|
||||
|
||||
/// Copy the specified array of elements into memory managed by our bump
|
||||
/// pointer allocator. The elements are all PODs by constructions.
|
||||
static NestedMatch copyInto(ArrayRef<NestedMatch::EntryType> elements);
|
||||
|
||||
/// POD payload.
|
||||
NestedMatchStorage *storage;
|
||||
};
|
||||
|
||||
/// A MLFunctionMatcher is a special type of InstWalker that:
|
||||
/// A NestedPattern is a special type of InstWalker that:
|
||||
/// 1. recursively matches a substructure in the tree;
|
||||
/// 2. uses a filter function to refine matches with extra semantic
|
||||
/// constraints (passed via a lambda of type FilterFunctionType);
|
||||
|
@ -89,78 +96,76 @@ private:
|
|||
///
|
||||
/// Implemented as a POD value-type with underlying storage pointer.
|
||||
/// The underlying storage lives in a scoped bumper allocator whose lifetime
|
||||
/// is managed by an RAII MLFunctionMatcherContext.
|
||||
/// is managed by an RAII NestedPatternContext.
|
||||
/// This should be used by value everywhere.
|
||||
using FilterFunctionType = std::function<bool(const Instruction &)>;
|
||||
static bool defaultFilterFunction(const Instruction &) { return true; };
|
||||
struct MLFunctionMatcher : public InstWalker<MLFunctionMatcher> {
|
||||
MLFunctionMatcher(Instruction::Kind k, MLFunctionMatcher child,
|
||||
FilterFunctionType filter = defaultFilterFunction);
|
||||
MLFunctionMatcher(Instruction::Kind k,
|
||||
MutableArrayRef<MLFunctionMatcher> children,
|
||||
FilterFunctionType filter = defaultFilterFunction);
|
||||
struct NestedPattern : public InstWalker<NestedPattern> {
|
||||
NestedPattern(Instruction::Kind k, ArrayRef<NestedPattern> nested,
|
||||
FilterFunctionType filter = defaultFilterFunction);
|
||||
NestedPattern(const NestedPattern &) = default;
|
||||
NestedPattern &operator=(const NestedPattern &) = default;
|
||||
|
||||
/// Returns all the matches in `function`.
|
||||
MLFunctionMatches match(Function *function);
|
||||
NestedMatch match(Function *function);
|
||||
|
||||
/// Returns all the matches nested under `instruction`.
|
||||
MLFunctionMatches match(Instruction *instruction);
|
||||
NestedMatch match(Instruction *instruction);
|
||||
|
||||
unsigned getDepth();
|
||||
|
||||
private:
|
||||
friend class MLFunctionMatcherContext;
|
||||
friend InstWalker<MLFunctionMatcher>;
|
||||
friend class NestedPatternContext;
|
||||
friend InstWalker<NestedPattern>;
|
||||
|
||||
/// Underlying global bump allocator managed by a NestedPatternContext.
|
||||
static llvm::BumpPtrAllocator *&allocator();
|
||||
|
||||
Instruction::Kind getKind();
|
||||
MutableArrayRef<MLFunctionMatcher> getChildrenMLFunctionMatchers();
|
||||
ArrayRef<NestedPattern> getNestedPatterns();
|
||||
FilterFunctionType getFilterFunction();
|
||||
|
||||
MLFunctionMatcher forkMLFunctionMatcherAt(MLFunctionMatcher tmpl,
|
||||
Instruction *inst);
|
||||
|
||||
void matchOne(Instruction *elem);
|
||||
|
||||
void visitForInst(ForInst *forInst) { matchOne(forInst); }
|
||||
void visitIfInst(IfInst *ifInst) { matchOne(ifInst); }
|
||||
void visitOperationInst(OperationInst *opInst) { matchOne(opInst); }
|
||||
|
||||
/// Underlying global bump allocator managed by an MLFunctionMatcherContext.
|
||||
static llvm::BumpPtrAllocator *&allocator();
|
||||
|
||||
MLFunctionMatcherStorage *storage;
|
||||
/// POD paylod.
|
||||
/// Storage for the PatternMatcher.
|
||||
NestedPatternStorage *storage;
|
||||
|
||||
// By-value POD wrapper to underlying storage pointer.
|
||||
MLFunctionMatches matches;
|
||||
NestedMatch matches;
|
||||
};
|
||||
|
||||
/// RAII structure to transparently manage the bump allocator for
|
||||
/// MLFunctionMatcher and MLFunctionMatches classes.
|
||||
struct MLFunctionMatcherContext {
|
||||
MLFunctionMatcherContext() {
|
||||
MLFunctionMatcher::allocator() = &allocator;
|
||||
MLFunctionMatches::allocator() = &allocator;
|
||||
/// NestedPattern and NestedMatch classes.
|
||||
struct NestedPatternContext {
|
||||
NestedPatternContext() {
|
||||
NestedPattern::allocator() = &allocator;
|
||||
NestedMatch::allocator() = &allocator;
|
||||
}
|
||||
~MLFunctionMatcherContext() {
|
||||
MLFunctionMatcher::allocator() = nullptr;
|
||||
MLFunctionMatches::allocator() = nullptr;
|
||||
~NestedPatternContext() {
|
||||
NestedPattern::allocator() = nullptr;
|
||||
NestedMatch::allocator() = nullptr;
|
||||
}
|
||||
llvm::BumpPtrAllocator allocator;
|
||||
};
|
||||
|
||||
namespace matcher {
|
||||
// Syntactic sugar MLFunctionMatcher builder functions.
|
||||
MLFunctionMatcher Op(FilterFunctionType filter = defaultFilterFunction);
|
||||
MLFunctionMatcher If(MLFunctionMatcher child);
|
||||
MLFunctionMatcher If(FilterFunctionType filter, MLFunctionMatcher child);
|
||||
MLFunctionMatcher If(MutableArrayRef<MLFunctionMatcher> children = {});
|
||||
MLFunctionMatcher If(FilterFunctionType filter,
|
||||
MutableArrayRef<MLFunctionMatcher> children = {});
|
||||
MLFunctionMatcher For(MLFunctionMatcher child);
|
||||
MLFunctionMatcher For(FilterFunctionType filter, MLFunctionMatcher child);
|
||||
MLFunctionMatcher For(MutableArrayRef<MLFunctionMatcher> children = {});
|
||||
MLFunctionMatcher For(FilterFunctionType filter,
|
||||
MutableArrayRef<MLFunctionMatcher> children = {});
|
||||
// Syntactic sugar NestedPattern builder functions.
|
||||
NestedPattern Op(FilterFunctionType filter = defaultFilterFunction);
|
||||
NestedPattern If(NestedPattern child);
|
||||
NestedPattern If(FilterFunctionType filter, NestedPattern child);
|
||||
NestedPattern If(ArrayRef<NestedPattern> nested = {});
|
||||
NestedPattern If(FilterFunctionType filter,
|
||||
ArrayRef<NestedPattern> nested = {});
|
||||
NestedPattern For(NestedPattern child);
|
||||
NestedPattern For(FilterFunctionType filter, NestedPattern child);
|
||||
NestedPattern For(ArrayRef<NestedPattern> nested = {});
|
||||
NestedPattern For(FilterFunctionType filter,
|
||||
ArrayRef<NestedPattern> nested = {});
|
||||
|
||||
bool isParallelLoop(const Instruction &inst);
|
||||
bool isReductionLoop(const Instruction &inst);
|
|
@ -23,7 +23,7 @@
|
|||
// generally designed to be automatically generated from various IR dialects in
|
||||
// the future.
|
||||
// The implementation is supported by a lightweight by-value abstraction on a
|
||||
// scoped BumpAllocator with similarities to AffineExpr and MLFunctionMatcher.
|
||||
// scoped BumpAllocator with similarities to AffineExpr and NestedPattern.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
|
|
|
@ -23,7 +23,7 @@
|
|||
|
||||
#include "mlir/Analysis/AffineAnalysis.h"
|
||||
#include "mlir/Analysis/AffineStructures.h"
|
||||
#include "mlir/Analysis/MLFunctionMatcher.h"
|
||||
#include "mlir/Analysis/NestedMatcher.h"
|
||||
#include "mlir/Analysis/VectorAnalysis.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/BuiltinOps.h"
|
||||
|
|
|
@ -1,263 +0,0 @@
|
|||
//===- MLFunctionMatcher.cpp - MLFunctionMatcher Impl ----------*- C++ -*-===//
|
||||
//
|
||||
// Copyright 2019 The MLIR Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
// =============================================================================
|
||||
|
||||
#include "mlir/Analysis/MLFunctionMatcher.h"
|
||||
#include "mlir/StandardOps/StandardOps.h"
|
||||
|
||||
#include "llvm/Support/Allocator.h"
|
||||
|
||||
namespace mlir {
|
||||
|
||||
/// Underlying storage for MLFunctionMatches.
|
||||
struct MLFunctionMatchesStorage {
|
||||
MLFunctionMatchesStorage(MLFunctionMatches::EntryType e) : matches({e}) {}
|
||||
|
||||
SmallVector<MLFunctionMatches::EntryType, 8> matches;
|
||||
};
|
||||
|
||||
/// Underlying storage for MLFunctionMatcher.
|
||||
struct MLFunctionMatcherStorage {
|
||||
MLFunctionMatcherStorage(Instruction::Kind k,
|
||||
MutableArrayRef<MLFunctionMatcher> c,
|
||||
FilterFunctionType filter, Instruction *skip)
|
||||
: kind(k), childrenMLFunctionMatchers(c.begin(), c.end()), filter(filter),
|
||||
skip(skip) {}
|
||||
|
||||
Instruction::Kind kind;
|
||||
SmallVector<MLFunctionMatcher, 4> childrenMLFunctionMatchers;
|
||||
FilterFunctionType filter;
|
||||
/// skip is needed so that we can implement match without switching on the
|
||||
/// type of the Instruction.
|
||||
/// The idea is that a MLFunctionMatcher first checks if it matches locally
|
||||
/// and then recursively applies its children matchers to its elem->children.
|
||||
/// Since we want to rely on the InstWalker impl rather than duplicate its
|
||||
/// the logic, we allow an off-by-one traversal to account for the fact that
|
||||
/// we write:
|
||||
///
|
||||
/// void match(Instruction *elem) {
|
||||
/// for (auto &c : getChildrenMLFunctionMatchers()) {
|
||||
/// MLFunctionMatcher childMLFunctionMatcher(...);
|
||||
/// ^~~~ Needs off-by-one skip.
|
||||
///
|
||||
Instruction *skip;
|
||||
};
|
||||
|
||||
} // end namespace mlir
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
llvm::BumpPtrAllocator *&MLFunctionMatches::allocator() {
|
||||
static thread_local llvm::BumpPtrAllocator *allocator = nullptr;
|
||||
return allocator;
|
||||
}
|
||||
|
||||
void MLFunctionMatches::append(Instruction *inst, MLFunctionMatches children) {
|
||||
if (!storage) {
|
||||
storage = allocator()->Allocate<MLFunctionMatchesStorage>();
|
||||
new (storage) MLFunctionMatchesStorage(std::make_pair(inst, children));
|
||||
} else {
|
||||
storage->matches.push_back(std::make_pair(inst, children));
|
||||
}
|
||||
}
|
||||
MLFunctionMatches::iterator MLFunctionMatches::begin() {
|
||||
return storage ? storage->matches.begin() : nullptr;
|
||||
}
|
||||
MLFunctionMatches::iterator MLFunctionMatches::end() {
|
||||
return storage ? storage->matches.end() : nullptr;
|
||||
}
|
||||
MLFunctionMatches::EntryType &MLFunctionMatches::front() {
|
||||
assert(storage && "null storage");
|
||||
return *storage->matches.begin();
|
||||
}
|
||||
MLFunctionMatches::EntryType &MLFunctionMatches::back() {
|
||||
assert(storage && "null storage");
|
||||
return *(storage->matches.begin() + size() - 1);
|
||||
}
|
||||
/// Return the combination of multiple MLFunctionMatches as a new object.
|
||||
static MLFunctionMatches combine(ArrayRef<MLFunctionMatches> matches) {
|
||||
MLFunctionMatches res;
|
||||
for (auto s : matches) {
|
||||
for (auto ss : s) {
|
||||
res.append(ss.first, ss.second);
|
||||
}
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
||||
/// Calls walk on `function`.
|
||||
MLFunctionMatches MLFunctionMatcher::match(Function *function) {
|
||||
assert(!matches && "MLFunctionMatcher already matched!");
|
||||
this->walkPostOrder(function);
|
||||
return matches;
|
||||
}
|
||||
|
||||
/// Calls walk on `instruction`.
|
||||
MLFunctionMatches MLFunctionMatcher::match(Instruction *instruction) {
|
||||
assert(!matches && "MLFunctionMatcher already matched!");
|
||||
this->walkPostOrder(instruction);
|
||||
return matches;
|
||||
}
|
||||
|
||||
unsigned MLFunctionMatcher::getDepth() {
|
||||
auto children = getChildrenMLFunctionMatchers();
|
||||
if (children.empty()) {
|
||||
return 1;
|
||||
}
|
||||
unsigned depth = 0;
|
||||
for (auto &c : children) {
|
||||
depth = std::max(depth, c.getDepth());
|
||||
}
|
||||
return depth + 1;
|
||||
}
|
||||
|
||||
/// Matches a single instruction in the following way:
|
||||
/// 1. checks the kind of instruction against the matcher, if different then
|
||||
/// there is no match;
|
||||
/// 2. calls the customizable filter function to refine the single instruction
|
||||
/// match with extra semantic constraints;
|
||||
/// 3. if all is good, recursivey matches the children patterns;
|
||||
/// 4. if all children match then the single instruction matches too and is
|
||||
/// appended to the list of matches;
|
||||
/// 5. TODO(ntv) Optionally applies actions (lambda), in which case we will
|
||||
/// want to traverse in post-order DFS to avoid invalidating iterators.
|
||||
void MLFunctionMatcher::matchOne(Instruction *elem) {
|
||||
if (storage->skip == elem) {
|
||||
return;
|
||||
}
|
||||
// Structural filter
|
||||
if (elem->getKind() != getKind()) {
|
||||
return;
|
||||
}
|
||||
// Local custom filter function
|
||||
if (!getFilterFunction()(*elem)) {
|
||||
return;
|
||||
}
|
||||
SmallVector<MLFunctionMatches, 8> childrenMLFunctionMatches;
|
||||
for (auto &c : getChildrenMLFunctionMatchers()) {
|
||||
/// We create a new childMLFunctionMatcher here because a matcher holds its
|
||||
/// results. So we concretely need multiple copies of a given matcher, one
|
||||
/// for each matching result.
|
||||
MLFunctionMatcher childMLFunctionMatcher = forkMLFunctionMatcherAt(c, elem);
|
||||
childMLFunctionMatcher.walkPostOrder(elem);
|
||||
if (!childMLFunctionMatcher.matches) {
|
||||
return;
|
||||
}
|
||||
childrenMLFunctionMatches.push_back(childMLFunctionMatcher.matches);
|
||||
}
|
||||
matches.append(elem, combine(childrenMLFunctionMatches));
|
||||
}
|
||||
|
||||
llvm::BumpPtrAllocator *&MLFunctionMatcher::allocator() {
|
||||
static thread_local llvm::BumpPtrAllocator *allocator = nullptr;
|
||||
return allocator;
|
||||
}
|
||||
|
||||
MLFunctionMatcher::MLFunctionMatcher(Instruction::Kind k,
|
||||
MLFunctionMatcher child,
|
||||
FilterFunctionType filter)
|
||||
: storage(allocator()->Allocate<MLFunctionMatcherStorage>()) {
|
||||
// Initialize with placement new.
|
||||
new (storage)
|
||||
MLFunctionMatcherStorage(k, {child}, filter, nullptr /* skip */);
|
||||
}
|
||||
|
||||
MLFunctionMatcher::MLFunctionMatcher(
|
||||
Instruction::Kind k, MutableArrayRef<MLFunctionMatcher> children,
|
||||
FilterFunctionType filter)
|
||||
: storage(allocator()->Allocate<MLFunctionMatcherStorage>()) {
|
||||
// Initialize with placement new.
|
||||
new (storage)
|
||||
MLFunctionMatcherStorage(k, children, filter, nullptr /* skip */);
|
||||
}
|
||||
|
||||
MLFunctionMatcher
|
||||
MLFunctionMatcher::forkMLFunctionMatcherAt(MLFunctionMatcher tmpl,
|
||||
Instruction *inst) {
|
||||
MLFunctionMatcher res(tmpl.getKind(), tmpl.getChildrenMLFunctionMatchers(),
|
||||
tmpl.getFilterFunction());
|
||||
res.storage->skip = inst;
|
||||
return res;
|
||||
}
|
||||
|
||||
Instruction::Kind MLFunctionMatcher::getKind() { return storage->kind; }
|
||||
|
||||
MutableArrayRef<MLFunctionMatcher>
|
||||
MLFunctionMatcher::getChildrenMLFunctionMatchers() {
|
||||
return storage->childrenMLFunctionMatchers;
|
||||
}
|
||||
|
||||
FilterFunctionType MLFunctionMatcher::getFilterFunction() {
|
||||
return storage->filter;
|
||||
}
|
||||
|
||||
namespace mlir {
|
||||
namespace matcher {
|
||||
|
||||
MLFunctionMatcher Op(FilterFunctionType filter) {
|
||||
return MLFunctionMatcher(Instruction::Kind::OperationInst, {}, filter);
|
||||
}
|
||||
|
||||
MLFunctionMatcher If(MLFunctionMatcher child) {
|
||||
return MLFunctionMatcher(Instruction::Kind::If, child, defaultFilterFunction);
|
||||
}
|
||||
MLFunctionMatcher If(FilterFunctionType filter, MLFunctionMatcher child) {
|
||||
return MLFunctionMatcher(Instruction::Kind::If, child, filter);
|
||||
}
|
||||
MLFunctionMatcher If(MutableArrayRef<MLFunctionMatcher> children) {
|
||||
return MLFunctionMatcher(Instruction::Kind::If, children,
|
||||
defaultFilterFunction);
|
||||
}
|
||||
MLFunctionMatcher If(FilterFunctionType filter,
|
||||
MutableArrayRef<MLFunctionMatcher> children) {
|
||||
return MLFunctionMatcher(Instruction::Kind::If, children, filter);
|
||||
}
|
||||
|
||||
MLFunctionMatcher For(MLFunctionMatcher child) {
|
||||
return MLFunctionMatcher(Instruction::Kind::For, child,
|
||||
defaultFilterFunction);
|
||||
}
|
||||
MLFunctionMatcher For(FilterFunctionType filter, MLFunctionMatcher child) {
|
||||
return MLFunctionMatcher(Instruction::Kind::For, child, filter);
|
||||
}
|
||||
MLFunctionMatcher For(MutableArrayRef<MLFunctionMatcher> children) {
|
||||
return MLFunctionMatcher(Instruction::Kind::For, children,
|
||||
defaultFilterFunction);
|
||||
}
|
||||
MLFunctionMatcher For(FilterFunctionType filter,
|
||||
MutableArrayRef<MLFunctionMatcher> children) {
|
||||
return MLFunctionMatcher(Instruction::Kind::For, children, filter);
|
||||
}
|
||||
|
||||
// TODO(ntv): parallel annotation on loops.
|
||||
bool isParallelLoop(const Instruction &inst) {
|
||||
const auto *loop = cast<ForInst>(&inst);
|
||||
return (void *)loop || true; // loop->isParallel();
|
||||
};
|
||||
|
||||
// TODO(ntv): reduction annotation on loops.
|
||||
bool isReductionLoop(const Instruction &inst) {
|
||||
const auto *loop = cast<ForInst>(&inst);
|
||||
return (void *)loop || true; // loop->isReduction();
|
||||
};
|
||||
|
||||
bool isLoadOrStore(const Instruction &inst) {
|
||||
const auto *opInst = dyn_cast<OperationInst>(&inst);
|
||||
return opInst && (opInst->isa<LoadOp>() || opInst->isa<StoreOp>());
|
||||
};
|
||||
|
||||
} // end namespace matcher
|
||||
} // end namespace mlir
|
|
@ -0,0 +1,240 @@
|
|||
//===- NestedMatcher.cpp - NestedMatcher Impl ------------------*- C++ -*-===//
|
||||
//
|
||||
// Copyright 2019 The MLIR Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
// =============================================================================
|
||||
|
||||
#include "mlir/Analysis/NestedMatcher.h"
|
||||
#include "mlir/StandardOps/StandardOps.h"
|
||||
|
||||
#include "llvm/ADT/ArrayRef.h"
|
||||
#include "llvm/Support/Allocator.h"
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
|
||||
namespace mlir {
|
||||
|
||||
/// Underlying storage for NestedMatch.
|
||||
struct NestedMatchStorage {
|
||||
MutableArrayRef<NestedMatch::EntryType> matches;
|
||||
};
|
||||
|
||||
/// Underlying storage for NestedPattern.
|
||||
struct NestedPatternStorage {
|
||||
NestedPatternStorage(Instruction::Kind k, ArrayRef<NestedPattern> c,
|
||||
FilterFunctionType filter, Instruction *skip)
|
||||
: kind(k), nestedPatterns(c), filter(filter), skip(skip) {}
|
||||
|
||||
Instruction::Kind kind;
|
||||
ArrayRef<NestedPattern> nestedPatterns;
|
||||
FilterFunctionType filter;
|
||||
/// skip is needed so that we can implement match without switching on the
|
||||
/// type of the Instruction.
|
||||
/// The idea is that a NestedPattern first checks if it matches locally
|
||||
/// and then recursively applies its nested matchers to its elem->nested.
|
||||
/// Since we want to rely on the InstWalker impl rather than duplicate its
|
||||
/// the logic, we allow an off-by-one traversal to account for the fact that
|
||||
/// we write:
|
||||
///
|
||||
/// void match(Instruction *elem) {
|
||||
/// for (auto &c : getNestedPatterns()) {
|
||||
/// NestedPattern childPattern(...);
|
||||
/// ^~~~ Needs off-by-one skip.
|
||||
///
|
||||
Instruction *skip;
|
||||
};
|
||||
|
||||
} // end namespace mlir
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
llvm::BumpPtrAllocator *&NestedMatch::allocator() {
|
||||
static thread_local llvm::BumpPtrAllocator *allocator = nullptr;
|
||||
return allocator;
|
||||
}
|
||||
|
||||
NestedMatch NestedMatch::build(ArrayRef<NestedMatch::EntryType> elements) {
|
||||
auto *matches =
|
||||
allocator()->Allocate<NestedMatch::EntryType>(elements.size());
|
||||
std::uninitialized_copy(elements.begin(), elements.end(), matches);
|
||||
auto *storage = allocator()->Allocate<NestedMatchStorage>();
|
||||
new (storage) NestedMatchStorage();
|
||||
storage->matches =
|
||||
MutableArrayRef<NestedMatch::EntryType>(matches, elements.size());
|
||||
auto *result = allocator()->Allocate<NestedMatch>();
|
||||
new (result) NestedMatch(storage);
|
||||
return *result;
|
||||
}
|
||||
|
||||
NestedMatch::iterator NestedMatch::begin() { return storage->matches.begin(); }
|
||||
NestedMatch::iterator NestedMatch::end() { return storage->matches.end(); }
|
||||
NestedMatch::EntryType &NestedMatch::front() {
|
||||
return *storage->matches.begin();
|
||||
}
|
||||
NestedMatch::EntryType &NestedMatch::back() {
|
||||
return *(storage->matches.begin() + size() - 1);
|
||||
}
|
||||
|
||||
/// Calls walk on `function`.
|
||||
NestedMatch NestedPattern::match(Function *function) {
|
||||
assert(!matches && "NestedPattern already matched!");
|
||||
this->walkPostOrder(function);
|
||||
return matches;
|
||||
}
|
||||
|
||||
/// Calls walk on `instruction`.
|
||||
NestedMatch NestedPattern::match(Instruction *instruction) {
|
||||
assert(!matches && "NestedPattern already matched!");
|
||||
this->walkPostOrder(instruction);
|
||||
return matches;
|
||||
}
|
||||
|
||||
unsigned NestedPattern::getDepth() {
|
||||
auto nested = getNestedPatterns();
|
||||
if (nested.empty()) {
|
||||
return 1;
|
||||
}
|
||||
unsigned depth = 0;
|
||||
for (auto c : nested) {
|
||||
depth = std::max(depth, c.getDepth());
|
||||
}
|
||||
return depth + 1;
|
||||
}
|
||||
|
||||
/// Matches a single instruction in the following way:
|
||||
/// 1. checks the kind of instruction against the matcher, if different then
|
||||
/// there is no match;
|
||||
/// 2. calls the customizable filter function to refine the single instruction
|
||||
/// match with extra semantic constraints;
|
||||
/// 3. if all is good, recursivey matches the nested patterns;
|
||||
/// 4. if all nested match then the single instruction matches too and is
|
||||
/// appended to the list of matches;
|
||||
/// 5. TODO(ntv) Optionally applies actions (lambda), in which case we will
|
||||
/// want to traverse in post-order DFS to avoid invalidating iterators.
|
||||
void NestedPattern::matchOne(Instruction *elem) {
|
||||
if (storage->skip == elem) {
|
||||
return;
|
||||
}
|
||||
// Structural filter
|
||||
if (elem->getKind() != getKind()) {
|
||||
return;
|
||||
}
|
||||
// Local custom filter function
|
||||
if (!getFilterFunction()(*elem)) {
|
||||
return;
|
||||
}
|
||||
|
||||
SmallVector<NestedMatch::EntryType, 8> nestedEntries;
|
||||
for (auto c : getNestedPatterns()) {
|
||||
/// We create a new nestedPattern here because a matcher holds its
|
||||
/// results. So we concretely need multiple copies of a given matcher, one
|
||||
/// for each matching result.
|
||||
NestedPattern nestedPattern(c);
|
||||
// Skip elem in the walk immediately following. Without this we would
|
||||
// essentially need to reimplement walkPostOrder here.
|
||||
nestedPattern.storage->skip = elem;
|
||||
nestedPattern.walkPostOrder(elem);
|
||||
if (!nestedPattern.matches) {
|
||||
return;
|
||||
}
|
||||
for (auto m : nestedPattern.matches) {
|
||||
nestedEntries.push_back(m);
|
||||
}
|
||||
}
|
||||
|
||||
SmallVector<NestedMatch::EntryType, 8> newEntries(
|
||||
matches.storage->matches.begin(), matches.storage->matches.end());
|
||||
newEntries.push_back(std::make_pair(elem, NestedMatch::build(nestedEntries)));
|
||||
matches = NestedMatch::build(newEntries);
|
||||
}
|
||||
|
||||
llvm::BumpPtrAllocator *&NestedPattern::allocator() {
|
||||
static thread_local llvm::BumpPtrAllocator *allocator = nullptr;
|
||||
return allocator;
|
||||
}
|
||||
|
||||
NestedPattern::NestedPattern(Instruction::Kind k,
|
||||
ArrayRef<NestedPattern> nested,
|
||||
FilterFunctionType filter)
|
||||
: storage(allocator()->Allocate<NestedPatternStorage>()),
|
||||
matches(NestedMatch::build({})) {
|
||||
auto *newChildren = allocator()->Allocate<NestedPattern>(nested.size());
|
||||
std::uninitialized_copy(nested.begin(), nested.end(), newChildren);
|
||||
// Initialize with placement new.
|
||||
new (storage) NestedPatternStorage(
|
||||
k, ArrayRef<NestedPattern>(newChildren, nested.size()), filter,
|
||||
nullptr /* skip */);
|
||||
}
|
||||
|
||||
Instruction::Kind NestedPattern::getKind() { return storage->kind; }
|
||||
|
||||
ArrayRef<NestedPattern> NestedPattern::getNestedPatterns() {
|
||||
return storage->nestedPatterns;
|
||||
}
|
||||
|
||||
FilterFunctionType NestedPattern::getFilterFunction() {
|
||||
return storage->filter;
|
||||
}
|
||||
|
||||
namespace mlir {
|
||||
namespace matcher {
|
||||
|
||||
NestedPattern Op(FilterFunctionType filter) {
|
||||
return NestedPattern(Instruction::Kind::OperationInst, {}, filter);
|
||||
}
|
||||
|
||||
NestedPattern If(NestedPattern child) {
|
||||
return NestedPattern(Instruction::Kind::If, child, defaultFilterFunction);
|
||||
}
|
||||
NestedPattern If(FilterFunctionType filter, NestedPattern child) {
|
||||
return NestedPattern(Instruction::Kind::If, child, filter);
|
||||
}
|
||||
NestedPattern If(ArrayRef<NestedPattern> nested) {
|
||||
return NestedPattern(Instruction::Kind::If, nested, defaultFilterFunction);
|
||||
}
|
||||
NestedPattern If(FilterFunctionType filter, ArrayRef<NestedPattern> nested) {
|
||||
return NestedPattern(Instruction::Kind::If, nested, filter);
|
||||
}
|
||||
|
||||
NestedPattern For(NestedPattern child) {
|
||||
return NestedPattern(Instruction::Kind::For, child, defaultFilterFunction);
|
||||
}
|
||||
NestedPattern For(FilterFunctionType filter, NestedPattern child) {
|
||||
return NestedPattern(Instruction::Kind::For, child, filter);
|
||||
}
|
||||
NestedPattern For(ArrayRef<NestedPattern> nested) {
|
||||
return NestedPattern(Instruction::Kind::For, nested, defaultFilterFunction);
|
||||
}
|
||||
NestedPattern For(FilterFunctionType filter, ArrayRef<NestedPattern> nested) {
|
||||
return NestedPattern(Instruction::Kind::For, nested, filter);
|
||||
}
|
||||
|
||||
// TODO(ntv): parallel annotation on loops.
|
||||
bool isParallelLoop(const Instruction &inst) {
|
||||
const auto *loop = cast<ForInst>(&inst);
|
||||
return (void *)loop || true; // loop->isParallel();
|
||||
};
|
||||
|
||||
// TODO(ntv): reduction annotation on loops.
|
||||
bool isReductionLoop(const Instruction &inst) {
|
||||
const auto *loop = cast<ForInst>(&inst);
|
||||
return (void *)loop || true; // loop->isReduction();
|
||||
};
|
||||
|
||||
bool isLoadOrStore(const Instruction &inst) {
|
||||
const auto *opInst = dyn_cast<OperationInst>(&inst);
|
||||
return opInst && (opInst->isa<LoadOp>() || opInst->isa<StoreOp>());
|
||||
};
|
||||
|
||||
} // end namespace matcher
|
||||
} // end namespace mlir
|
|
@ -22,7 +22,7 @@
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/Analysis/AffineAnalysis.h"
|
||||
#include "mlir/Analysis/MLFunctionMatcher.h"
|
||||
#include "mlir/Analysis/NestedMatcher.h"
|
||||
#include "mlir/IR/AffineMap.h"
|
||||
#include "mlir/IR/Attributes.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
|
@ -49,7 +49,7 @@ struct ComposeAffineMaps : public FunctionPass {
|
|||
PassResult runOnFunction(Function *f) override;
|
||||
|
||||
// Thread-safe RAII contexts local to pass, BumpPtrAllocator freed on exit.
|
||||
MLFunctionMatcherContext MLContext;
|
||||
NestedPatternContext MLContext;
|
||||
|
||||
static char passID;
|
||||
};
|
||||
|
@ -74,8 +74,7 @@ PassResult ComposeAffineMaps::runOnFunction(Function *f) {
|
|||
auto apps = pattern.match(f);
|
||||
for (auto m : apps) {
|
||||
auto app = cast<OperationInst>(m.first)->cast<AffineApplyOp>();
|
||||
SmallVector<Value *, 8> operands(app->getOperands().begin(),
|
||||
app->getOperands().end());
|
||||
SmallVector<Value *, 8> operands(app->getOperands());
|
||||
FuncBuilder b(m.first);
|
||||
auto newApp = makeComposedAffineApply(&b, app->getLoc(),
|
||||
app->getAffineMap(), operands);
|
||||
|
|
|
@ -22,7 +22,7 @@
|
|||
#include <type_traits>
|
||||
|
||||
#include "mlir/Analysis/AffineAnalysis.h"
|
||||
#include "mlir/Analysis/MLFunctionMatcher.h"
|
||||
#include "mlir/Analysis/NestedMatcher.h"
|
||||
#include "mlir/Analysis/Utils.h"
|
||||
#include "mlir/Analysis/VectorAnalysis.h"
|
||||
#include "mlir/EDSC/MLIREmitter.h"
|
||||
|
|
|
@ -23,7 +23,7 @@
|
|||
#include "mlir/Analysis/AffineAnalysis.h"
|
||||
#include "mlir/Analysis/Dominance.h"
|
||||
#include "mlir/Analysis/LoopAnalysis.h"
|
||||
#include "mlir/Analysis/MLFunctionMatcher.h"
|
||||
#include "mlir/Analysis/NestedMatcher.h"
|
||||
#include "mlir/Analysis/SliceAnalysis.h"
|
||||
#include "mlir/Analysis/Utils.h"
|
||||
#include "mlir/Analysis/VectorAnalysis.h"
|
||||
|
@ -200,7 +200,7 @@ struct MaterializeVectorsPass : public FunctionPass {
|
|||
PassResult runOnFunction(Function *f) override;
|
||||
|
||||
// Thread-safe RAII contexts local to pass, BumpPtrAllocator freed on exit.
|
||||
MLFunctionMatcherContext mlContext;
|
||||
NestedPatternContext mlContext;
|
||||
|
||||
static char passID;
|
||||
};
|
||||
|
|
|
@ -20,7 +20,7 @@
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/Analysis/AffineAnalysis.h"
|
||||
#include "mlir/Analysis/MLFunctionMatcher.h"
|
||||
#include "mlir/Analysis/NestedMatcher.h"
|
||||
#include "mlir/Analysis/SliceAnalysis.h"
|
||||
#include "mlir/Analysis/VectorAnalysis.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
|
@ -95,7 +95,7 @@ struct VectorizerTestPass : public FunctionPass {
|
|||
void testNormalizeMaps(Function *f);
|
||||
|
||||
// Thread-safe RAII contexts local to pass, BumpPtrAllocator freed on exit.
|
||||
MLFunctionMatcherContext MLContext;
|
||||
NestedPatternContext MLContext;
|
||||
|
||||
static char passID;
|
||||
};
|
||||
|
@ -153,7 +153,7 @@ static std::string toString(Instruction *inst) {
|
|||
return res;
|
||||
}
|
||||
|
||||
static MLFunctionMatches matchTestSlicingOps(Function *f) {
|
||||
static NestedMatch matchTestSlicingOps(Function *f) {
|
||||
// Just use a custom op name for this test, it makes life easier.
|
||||
constexpr auto kTestSlicingOpName = "slicing-test-op";
|
||||
using functional::map;
|
||||
|
|
|
@ -21,7 +21,7 @@
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/Analysis/LoopAnalysis.h"
|
||||
#include "mlir/Analysis/MLFunctionMatcher.h"
|
||||
#include "mlir/Analysis/NestedMatcher.h"
|
||||
#include "mlir/Analysis/VectorAnalysis.h"
|
||||
#include "mlir/IR/AffineExpr.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
|
@ -567,14 +567,14 @@ static FilterFunctionType
|
|||
isVectorizableLoopPtrFactory(unsigned fastestVaryingMemRefDimension);
|
||||
|
||||
// Build a bunch of predetermined patterns that will be traversed in order.
|
||||
// Due to the recursive nature of MLFunctionMatchers, this captures
|
||||
// Due to the recursive nature of NestedPatterns, this captures
|
||||
// arbitrarily nested pairs of loops at any position in the tree.
|
||||
/// Note that this currently only matches 2 nested loops and will be extended.
|
||||
// TODO(ntv): support 3-D loop patterns with a common reduction loop that can
|
||||
// be matched to GEMMs.
|
||||
static std::vector<MLFunctionMatcher> defaultPatterns() {
|
||||
static std::vector<NestedPattern> defaultPatterns() {
|
||||
using matcher::For;
|
||||
return std::vector<MLFunctionMatcher>{
|
||||
return std::vector<NestedPattern>{
|
||||
// 3-D patterns
|
||||
For(isVectorizableLoopPtrFactory(2),
|
||||
For(isVectorizableLoopPtrFactory(1),
|
||||
|
@ -627,7 +627,7 @@ static std::vector<MLFunctionMatcher> defaultPatterns() {
|
|||
/// Up to 3-D patterns are supported.
|
||||
/// If the command line argument requests a pattern of higher order, returns an
|
||||
/// empty pattern list which will conservatively result in no vectorization.
|
||||
static std::vector<MLFunctionMatcher> makePatterns() {
|
||||
static std::vector<NestedPattern> makePatterns() {
|
||||
using matcher::For;
|
||||
if (clFastestVaryingPattern.empty()) {
|
||||
return defaultPatterns();
|
||||
|
@ -644,7 +644,7 @@ static std::vector<MLFunctionMatcher> makePatterns() {
|
|||
For(isVectorizableLoopPtrFactory(clFastestVaryingPattern[1]),
|
||||
For(isVectorizableLoopPtrFactory(clFastestVaryingPattern[2]))))};
|
||||
default:
|
||||
return std::vector<MLFunctionMatcher>();
|
||||
return std::vector<NestedPattern>();
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -656,7 +656,7 @@ struct Vectorize : public FunctionPass {
|
|||
PassResult runOnFunction(Function *f) override;
|
||||
|
||||
// Thread-safe RAII contexts local to pass, BumpPtrAllocator freed on exit.
|
||||
MLFunctionMatcherContext MLContext;
|
||||
NestedPatternContext MLContext;
|
||||
|
||||
static char passID;
|
||||
};
|
||||
|
@ -703,8 +703,8 @@ static void vectorizeLoopIfProfitable(ForInst *loop, unsigned depthInPattern,
|
|||
/// 3. account for impact of vectorization on maximal loop fusion.
|
||||
/// Then we can quantify the above to build a cost model and search over
|
||||
/// strategies.
|
||||
static bool analyzeProfitability(MLFunctionMatches matches,
|
||||
unsigned depthInPattern, unsigned patternDepth,
|
||||
static bool analyzeProfitability(NestedMatch matches, unsigned depthInPattern,
|
||||
unsigned patternDepth,
|
||||
VectorizationStrategy *strategy) {
|
||||
for (auto m : matches) {
|
||||
auto *loop = cast<ForInst>(m.first);
|
||||
|
@ -890,7 +890,7 @@ static bool vectorizeForInst(ForInst *loop, int64_t step,
|
|||
return false;
|
||||
}
|
||||
|
||||
/// Returns a FilterFunctionType that can be used in MLFunctionMatcher to
|
||||
/// Returns a FilterFunctionType that can be used in NestedPattern to
|
||||
/// match a loop whose underlying load/store accesses are all varying along the
|
||||
/// `fastestVaryingMemRefDimension`.
|
||||
/// TODO(ntv): In the future, allow more interesting mixed layout permutation
|
||||
|
@ -906,16 +906,15 @@ isVectorizableLoopPtrFactory(unsigned fastestVaryingMemRefDimension) {
|
|||
}
|
||||
|
||||
/// Forward-declaration.
|
||||
static bool vectorizeNonRoot(MLFunctionMatches matches,
|
||||
VectorizationState *state);
|
||||
static bool vectorizeNonRoot(NestedMatch matches, VectorizationState *state);
|
||||
|
||||
/// Apply vectorization of `loop` according to `state`. This is only triggered
|
||||
/// if all vectorizations in `childrenMatches` have already succeeded
|
||||
/// recursively in DFS post-order.
|
||||
static bool doVectorize(MLFunctionMatches::EntryType oneMatch,
|
||||
static bool doVectorize(NestedMatch::EntryType oneMatch,
|
||||
VectorizationState *state) {
|
||||
ForInst *loop = cast<ForInst>(oneMatch.first);
|
||||
MLFunctionMatches childrenMatches = oneMatch.second;
|
||||
NestedMatch childrenMatches = oneMatch.second;
|
||||
|
||||
// 1. DFS postorder recursion, if any of my children fails, I fail too.
|
||||
auto fail = vectorizeNonRoot(childrenMatches, state);
|
||||
|
@ -949,8 +948,7 @@ static bool doVectorize(MLFunctionMatches::EntryType oneMatch,
|
|||
|
||||
/// Non-root pattern iterates over the matches at this level, calls doVectorize
|
||||
/// and exits early if anything below fails.
|
||||
static bool vectorizeNonRoot(MLFunctionMatches matches,
|
||||
VectorizationState *state) {
|
||||
static bool vectorizeNonRoot(NestedMatch matches, VectorizationState *state) {
|
||||
for (auto m : matches) {
|
||||
auto fail = doVectorize(m, state);
|
||||
if (fail) {
|
||||
|
@ -1186,7 +1184,7 @@ static bool vectorizeOperations(VectorizationState *state) {
|
|||
/// The root match thus needs to maintain a clone for handling failure.
|
||||
/// Each root may succeed independently but will otherwise clean after itself if
|
||||
/// anything below it fails.
|
||||
static bool vectorizeRootMatches(MLFunctionMatches matches,
|
||||
static bool vectorizeRootMatches(NestedMatch matches,
|
||||
VectorizationStrategy *strategy) {
|
||||
for (auto m : matches) {
|
||||
auto *loop = cast<ForInst>(m.first);
|
||||
|
|
Loading…
Reference in New Issue