//===- NestedMatcher.cpp - NestedMatcher Impl ------------------*- C++ -*-===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================

#include "mlir/Analysis/NestedMatcher.h"
#include "mlir/AffineOps/AffineOps.h"
#include "mlir/StandardOps/StandardOps.h"

#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/Allocator.h"
#include "llvm/Support/raw_ostream.h"

using namespace mlir;
/// Returns the per-thread arena used to back NestedMatch storage.
/// The pointer starts out null; a scoped RAII helper is expected to install
/// (and later clear) the actual allocator — TODO confirm against the header.
llvm::BumpPtrAllocator *&NestedMatch::allocator() {
  thread_local llvm::BumpPtrAllocator *perThreadAllocator = nullptr;
  return perThreadAllocator;
}
/// Creates a NestedMatch for `instruction` with the given child matches.
/// Both the match object and a copy of `nestedMatches` are placed in the
/// thread-local arena, so the returned value (a shallow copy) stays valid as
/// long as that arena lives.
NestedMatch NestedMatch::build(Instruction *instruction,
                               ArrayRef<NestedMatch> nestedMatches) {
  auto *matchStorage = allocator()->Allocate<NestedMatch>();
  auto *childStorage = allocator()->Allocate<NestedMatch>(nestedMatches.size());
  // Copy the children into arena memory so the ArrayRef below does not dangle.
  std::uninitialized_copy(nestedMatches.begin(), nestedMatches.end(),
                          childStorage);
  new (matchStorage) NestedMatch();
  matchStorage->matchedInstruction = instruction;
  matchStorage->matchedChildren =
      ArrayRef<NestedMatch>(childStorage, nestedMatches.size());
  return *matchStorage;
}
/// Returns the per-thread arena used to back NestedPattern storage.
/// Null until an allocator is installed externally — TODO confirm the
/// installing scope against the header.
llvm::BumpPtrAllocator *&NestedPattern::allocator() {
  thread_local llvm::BumpPtrAllocator *perThreadAllocator = nullptr;
  return perThreadAllocator;
}
/// Builds a pattern matching instructions of kind `k` that satisfy `filter`,
/// with `nested` as the patterns required to match underneath. The nested
/// patterns are copied into the thread-local arena so this object does not
/// reference caller-owned memory.
NestedPattern::NestedPattern(Instruction::Kind k,
                             ArrayRef<NestedPattern> nested,
                             FilterFunctionType filter)
    : kind(k), nestedPatterns(), filter(filter), skip(nullptr) {
  if (nested.empty())
    return;
  auto *storage = allocator()->Allocate<NestedPattern>(nested.size());
  std::uninitialized_copy(nested.begin(), nested.end(), storage);
  nestedPatterns = ArrayRef<NestedPattern>(storage, nested.size());
}
unsigned NestedPattern::getDepth() const {
|
|
|
|
if (nestedPatterns.empty()) {
|
2019-01-26 22:59:23 +08:00
|
|
|
return 1;
|
|
|
|
}
|
|
|
|
unsigned depth = 0;
|
2019-01-31 23:16:29 +08:00
|
|
|
for (auto &c : nestedPatterns) {
|
2019-01-26 22:59:23 +08:00
|
|
|
depth = std::max(depth, c.getDepth());
|
|
|
|
}
|
|
|
|
return depth + 1;
|
|
|
|
}
|
|
|
|
|
|
|
|
/// Matches a single instruction in the following way:
///   1. checks the kind of instruction against the matcher, if different then
///      there is no match;
///   2. calls the customizable filter function to refine the single instruction
///      match with extra semantic constraints;
///   3. if all is good, recursively matches the nested patterns;
///   4. if all nested match then the single instruction matches too and is
///      appended to the list of matches;
///   5. TODO(ntv) Optionally applies actions (lambda), in which case we will
///      want to traverse in post-order DFS to avoid invalidating iterators.
void NestedPattern::matchOne(Instruction *inst,
                             SmallVectorImpl<NestedMatch> *matches) {
  // `skip` marks the instruction this pattern was launched from; without it,
  // a nested walk starting at `inst` would re-match `inst` itself.
  if (skip == inst) {
    return;
  }
  // Structural filter: the instruction kind must match the pattern's kind.
  if (inst->getKind() != kind) {
    return;
  }
  // Local custom filter function supplied at pattern-construction time.
  if (!filter(*inst)) {
    return;
  }

  // Leaf pattern: `inst` matches on its own, with no children.
  if (nestedPatterns.empty()) {
    SmallVector<NestedMatch, 8> nestedMatches;
    matches->push_back(NestedMatch::build(inst, nestedMatches));
    return;
  }
  // Take a copy of each nested pattern so we can match it. The copy is
  // required because we mutate `skip` below, and patterns may be shared.
  for (auto nestedPattern : nestedPatterns) {
    SmallVector<NestedMatch, 8> nestedMatches;
    // Skip elem in the walk immediately following. Without this we would
    // essentially need to reimplement walkPostOrder here.
    nestedPattern.skip = inst;
    nestedPattern.match(inst, &nestedMatches);
    // If we could not match even one of the specified nestedPattern, early exit
    // as this whole branch is not a match.
    if (nestedMatches.empty()) {
      return;
    }
    // NOTE(review): one NestedMatch is appended per nested pattern that
    // matches, all rooted at `inst` — so `inst` may appear multiple times in
    // `matches`. Confirm callers expect this grouping.
    matches->push_back(NestedMatch::build(inst, nestedMatches));
  }
}
static bool isAffineForOp(const Instruction &inst) {
|
|
|
|
return cast<OperationInst>(inst).isa<AffineForOp>();
|
|
|
|
}
|
|
|
|
|
2019-01-29 13:23:53 +08:00
|
|
|
static bool isAffineIfOp(const Instruction &inst) {
|
|
|
|
return isa<OperationInst>(inst) &&
|
|
|
|
cast<OperationInst>(inst).isa<AffineIfOp>();
|
|
|
|
}
|
|
|
|
|
2019-01-26 22:59:23 +08:00
|
|
|
namespace mlir {
|
|
|
|
namespace matcher {
|
|
|
|
|
|
|
|
/// Matches any single OperationInst satisfying `filter`, with no nesting
/// constraints.
NestedPattern Op(FilterFunctionType filter) {
  return NestedPattern(Instruction::Kind::OperationInst, /*nested=*/{},
                       filter);
}
/// Matches an affine.if operation with `child` nested underneath.
NestedPattern If(NestedPattern child) {
  return NestedPattern(Instruction::Kind::OperationInst, child,
                       isAffineIfOp);
}
/// Matches an affine.if operation that also satisfies `filter`, with `child`
/// nested underneath.
NestedPattern If(FilterFunctionType filter, NestedPattern child) {
  auto ifAndFilter = [filter](const Instruction &inst) {
    return isAffineIfOp(inst) && filter(inst);
  };
  return NestedPattern(Instruction::Kind::OperationInst, child, ifAndFilter);
}
/// Matches an affine.if operation with each of `nested` underneath.
NestedPattern If(ArrayRef<NestedPattern> nested) {
  return NestedPattern(Instruction::Kind::OperationInst, nested,
                       isAffineIfOp);
}
/// Matches an affine.if operation that also satisfies `filter`, with each of
/// `nested` underneath.
NestedPattern If(FilterFunctionType filter, ArrayRef<NestedPattern> nested) {
  auto ifAndFilter = [filter](const Instruction &inst) {
    return isAffineIfOp(inst) && filter(inst);
  };
  return NestedPattern(Instruction::Kind::OperationInst, nested, ifAndFilter);
}
/// Matches an affine.for operation with `child` nested underneath.
NestedPattern For(NestedPattern child) {
  return NestedPattern(Instruction::Kind::OperationInst, child,
                       isAffineForOp);
}
/// Matches an affine.for operation that also satisfies `filter`, with `child`
/// nested underneath.
NestedPattern For(FilterFunctionType filter, NestedPattern child) {
  auto forAndFilter = [filter](const Instruction &inst) {
    return isAffineForOp(inst) && filter(inst);
  };
  return NestedPattern(Instruction::Kind::OperationInst, child, forAndFilter);
}
/// Matches an affine.for operation with each of `nested` underneath.
NestedPattern For(ArrayRef<NestedPattern> nested) {
  return NestedPattern(Instruction::Kind::OperationInst, nested,
                       isAffineForOp);
}
/// Matches an affine.for operation that also satisfies `filter`, with each of
/// `nested` underneath.
NestedPattern For(FilterFunctionType filter, ArrayRef<NestedPattern> nested) {
  auto forAndFilter = [filter](const Instruction &inst) {
    return isAffineForOp(inst) && filter(inst);
  };
  return NestedPattern(Instruction::Kind::OperationInst, nested, forAndFilter);
}
// TODO(ntv): parallel annotation on loops.
|
|
|
|
bool isParallelLoop(const Instruction &inst) {
|
2019-02-02 08:42:18 +08:00
|
|
|
auto loop = cast<OperationInst>(inst).cast<AffineForOp>();
|
|
|
|
return loop || true; // loop->isParallel();
|
2019-01-26 22:59:23 +08:00
|
|
|
};
|
|
|
|
|
|
|
|
// TODO(ntv): reduction annotation on loops.
|
|
|
|
bool isReductionLoop(const Instruction &inst) {
|
2019-02-02 08:42:18 +08:00
|
|
|
auto loop = cast<OperationInst>(inst).cast<AffineForOp>();
|
|
|
|
return loop || true; // loop->isReduction();
|
2019-01-26 22:59:23 +08:00
|
|
|
};
|
|
|
|
|
|
|
|
bool isLoadOrStore(const Instruction &inst) {
|
|
|
|
const auto *opInst = dyn_cast<OperationInst>(&inst);
|
|
|
|
return opInst && (opInst->isa<LoadOp>() || opInst->isa<StoreOp>());
|
|
|
|
};
|
|
|
|
|
|
|
|
} // end namespace matcher
|
|
|
|
} // end namespace mlir
|