forked from OSchip/llvm-project
250 lines
9.6 KiB
C++
250 lines
9.6 KiB
C++
//===- ReductionTreePass.cpp - ReductionTreePass Implementation -----------===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
//
|
|
// This file defines the Reduction Tree Pass class. It provides a framework for
|
|
// the implementation of different reduction passes in the MLIR Reduce tool. It
|
|
// allows for custom specification of the variant generation behavior. It
|
|
// implements methods that define the different possible traversals of the
|
|
// reduction tree.
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "mlir/IR/DialectInterface.h"
|
|
#include "mlir/IR/OpDefinition.h"
|
|
#include "mlir/Reducer/PassDetail.h"
|
|
#include "mlir/Reducer/Passes.h"
|
|
#include "mlir/Reducer/ReductionNode.h"
|
|
#include "mlir/Reducer/ReductionPatternInterface.h"
|
|
#include "mlir/Reducer/Tester.h"
|
|
#include "mlir/Rewrite/FrozenRewritePatternSet.h"
|
|
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
|
|
|
#include "llvm/ADT/ArrayRef.h"
|
|
#include "llvm/ADT/SmallVector.h"
|
|
#include "llvm/Support/Allocator.h"
|
|
#include "llvm/Support/ManagedStatic.h"
|
|
|
|
using namespace mlir;
|
|
|
|
/// We implicitly number each operation in the region and if an operation's
|
|
/// number falls into rangeToKeep, we need to keep it and apply the given
|
|
/// rewrite patterns on it.
|
|
static void applyPatterns(Region ®ion,
|
|
const FrozenRewritePatternSet &patterns,
|
|
ArrayRef<ReductionNode::Range> rangeToKeep,
|
|
bool eraseOpNotInRange) {
|
|
std::vector<Operation *> opsNotInRange;
|
|
std::vector<Operation *> opsInRange;
|
|
size_t keepIndex = 0;
|
|
for (const auto &op : enumerate(region.getOps())) {
|
|
int index = op.index();
|
|
if (keepIndex < rangeToKeep.size() &&
|
|
index == rangeToKeep[keepIndex].second)
|
|
++keepIndex;
|
|
if (keepIndex == rangeToKeep.size() || index < rangeToKeep[keepIndex].first)
|
|
opsNotInRange.push_back(&op.value());
|
|
else
|
|
opsInRange.push_back(&op.value());
|
|
}
|
|
|
|
// `applyOpPatternsAndFold` may erase the ops so we can't do the pattern
|
|
// matching in above iteration. Besides, erase op not-in-range may end up in
|
|
// invalid module, so `applyOpPatternsAndFold` should come before that
|
|
// transform.
|
|
for (Operation *op : opsInRange)
|
|
// `applyOpPatternsAndFold` returns whether the op is convered. Omit it
|
|
// because we don't have expectation this reduction will be success or not.
|
|
(void)applyOpPatternsAndFold(op, patterns);
|
|
|
|
if (eraseOpNotInRange)
|
|
for (Operation *op : opsNotInRange) {
|
|
op->dropAllUses();
|
|
op->erase();
|
|
}
|
|
}
|
|
|
|
/// We will apply the reducer patterns to the operations in the ranges specified
|
|
/// by ReductionNode. Note that we are not able to remove an operation without
|
|
/// replacing it with another valid operation. However, The validity of module
|
|
/// reduction is based on the Tester provided by the user and that means certain
|
|
/// invalid module is still interested by the use. Thus we provide an
|
|
/// alternative way to remove operations, which is using `eraseOpNotInRange` to
|
|
/// erase the operations not in the range specified by ReductionNode.
|
|
template <typename IteratorType>
|
|
static LogicalResult findOptimal(ModuleOp module, Region ®ion,
|
|
const FrozenRewritePatternSet &patterns,
|
|
const Tester &test, bool eraseOpNotInRange) {
|
|
std::pair<Tester::Interestingness, size_t> initStatus =
|
|
test.isInteresting(module);
|
|
// While exploring the reduction tree, we always branch from an interesting
|
|
// node. Thus the root node must be interesting.
|
|
if (initStatus.first != Tester::Interestingness::True)
|
|
return module.emitWarning() << "uninterested module will not be reduced";
|
|
|
|
llvm::SpecificBumpPtrAllocator<ReductionNode> allocator;
|
|
|
|
std::vector<ReductionNode::Range> ranges{
|
|
{0, std::distance(region.op_begin(), region.op_end())}};
|
|
|
|
ReductionNode *root = allocator.Allocate();
|
|
new (root) ReductionNode(nullptr, ranges, allocator);
|
|
// Duplicate the module for root node and locate the region in the copy.
|
|
if (failed(root->initialize(module, region)))
|
|
llvm_unreachable("unexpected initialization failure");
|
|
root->update(initStatus);
|
|
|
|
ReductionNode *smallestNode = root;
|
|
IteratorType iter(root);
|
|
|
|
while (iter != IteratorType::end()) {
|
|
ReductionNode ¤tNode = *iter;
|
|
Region &curRegion = currentNode.getRegion();
|
|
|
|
applyPatterns(curRegion, patterns, currentNode.getRanges(),
|
|
eraseOpNotInRange);
|
|
currentNode.update(test.isInteresting(currentNode.getModule()));
|
|
|
|
if (currentNode.isInteresting() == Tester::Interestingness::True &&
|
|
currentNode.getSize() < smallestNode->getSize())
|
|
smallestNode = ¤tNode;
|
|
|
|
++iter;
|
|
}
|
|
|
|
// At here, we have found an optimal path to reduce the given region. Retrieve
|
|
// the path and apply the reducer to it.
|
|
SmallVector<ReductionNode *> trace;
|
|
ReductionNode *curNode = smallestNode;
|
|
trace.push_back(curNode);
|
|
while (curNode != root) {
|
|
curNode = curNode->getParent();
|
|
trace.push_back(curNode);
|
|
}
|
|
|
|
// Reduce the region through the optimal path.
|
|
while (!trace.empty()) {
|
|
ReductionNode *top = trace.pop_back_val();
|
|
applyPatterns(region, patterns, top->getStartRanges(), eraseOpNotInRange);
|
|
}
|
|
|
|
if (test.isInteresting(module).first != Tester::Interestingness::True)
|
|
llvm::report_fatal_error("Reduced module is not interesting");
|
|
if (test.isInteresting(module).second != smallestNode->getSize())
|
|
llvm::report_fatal_error(
|
|
"Reduced module doesn't have consistent size with smallestNode");
|
|
return success();
|
|
}
|
|
|
|
template <typename IteratorType>
|
|
static LogicalResult findOptimal(ModuleOp module, Region ®ion,
|
|
const FrozenRewritePatternSet &patterns,
|
|
const Tester &test) {
|
|
// We separate the reduction process into 2 steps, the first one is to erase
|
|
// redundant operations and the second one is to apply the reducer patterns.
|
|
|
|
// In the first phase, we don't apply any patterns so that we only select the
|
|
// range of operations to keep to the module stay interesting.
|
|
if (failed(findOptimal<IteratorType>(module, region, /*patterns=*/{}, test,
|
|
/*eraseOpNotInRange=*/true)))
|
|
return failure();
|
|
// In the second phase, we suppose that no operation is redundant, so we try
|
|
// to rewrite the operation into simpler form.
|
|
return findOptimal<IteratorType>(module, region, patterns, test,
|
|
/*eraseOpNotInRange=*/false);
|
|
}
|
|
|
|
namespace {
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Reduction Pattern Interface Collection
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
class ReductionPatternInterfaceCollection
|
|
: public DialectInterfaceCollection<DialectReductionPatternInterface> {
|
|
public:
|
|
using Base::Base;
|
|
|
|
// Collect the reduce patterns defined by each dialect.
|
|
void populateReductionPatterns(RewritePatternSet &pattern) const {
|
|
for (const DialectReductionPatternInterface &interface : *this)
|
|
interface.populateReductionPatterns(pattern);
|
|
}
|
|
};
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// ReductionTreePass
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// This class defines the Reduction Tree Pass. It provides a framework to
|
|
/// to implement a reduction pass using a tree structure to keep track of the
|
|
/// generated reduced variants.
|
|
class ReductionTreePass : public ReductionTreeBase<ReductionTreePass> {
|
|
public:
|
|
ReductionTreePass() = default;
|
|
ReductionTreePass(const ReductionTreePass &pass) = default;
|
|
|
|
LogicalResult initialize(MLIRContext *context) override;
|
|
|
|
/// Runs the pass instance in the pass pipeline.
|
|
void runOnOperation() override;
|
|
|
|
private:
|
|
LogicalResult reduceOp(ModuleOp module, Region ®ion);
|
|
|
|
FrozenRewritePatternSet reducerPatterns;
|
|
};
|
|
|
|
} // namespace
|
|
|
|
LogicalResult ReductionTreePass::initialize(MLIRContext *context) {
|
|
RewritePatternSet patterns(context);
|
|
ReductionPatternInterfaceCollection reducePatternCollection(context);
|
|
reducePatternCollection.populateReductionPatterns(patterns);
|
|
reducerPatterns = std::move(patterns);
|
|
return success();
|
|
}
|
|
|
|
void ReductionTreePass::runOnOperation() {
|
|
Operation *topOperation = getOperation();
|
|
while (topOperation->getParentOp() != nullptr)
|
|
topOperation = topOperation->getParentOp();
|
|
ModuleOp module = cast<ModuleOp>(topOperation);
|
|
|
|
SmallVector<Operation *, 8> workList;
|
|
workList.push_back(getOperation());
|
|
|
|
do {
|
|
Operation *op = workList.pop_back_val();
|
|
|
|
for (Region ®ion : op->getRegions())
|
|
if (!region.empty())
|
|
if (failed(reduceOp(module, region)))
|
|
return signalPassFailure();
|
|
|
|
for (Region ®ion : op->getRegions())
|
|
for (Operation &op : region.getOps())
|
|
if (op.getNumRegions() != 0)
|
|
workList.push_back(&op);
|
|
} while (!workList.empty());
|
|
}
|
|
|
|
LogicalResult ReductionTreePass::reduceOp(ModuleOp module, Region ®ion) {
|
|
Tester test(testerName, testerArgs);
|
|
switch (traversalModeId) {
|
|
case TraversalMode::SinglePath:
|
|
return findOptimal<ReductionNode::iterator<TraversalMode::SinglePath>>(
|
|
module, region, reducerPatterns, test);
|
|
default:
|
|
return module.emitError() << "unsupported traversal mode detected";
|
|
}
|
|
}
|
|
|
|
std::unique_ptr<Pass> mlir::createReductionTreePass() {
|
|
return std::make_unique<ReductionTreePass>();
|
|
}
|