[mlir-reduce] Reducer refactor.

* A Reducer is a kind of RewritePattern, so it's just the same as
writing graph rewrite.
* ReductionTreePass operates on Operation rather than ModuleOp, so that
* we are able to reduce a nested structure(e.g., module in module) by
* self-nesting.

Reviewed By: jpienaar, rriddle

Differential Revision: https://reviews.llvm.org/D101046
This commit is contained in:
Chia-hung Duan 2021-06-02 07:00:19 +08:00
parent 26044c6a54
commit c484c7dd9d
26 changed files with 519 additions and 371 deletions

View File

@ -1,41 +0,0 @@
//===- OptReductionPass.h - Optimization Reduction Pass Wrapper -*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file defines the Opt Reduction Pass Wrapper. It creates a MLIR pass to
// run any optimization pass within it and only replaces the output module with
// the transformed version if it is smaller and interesting.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_REDUCER_OPTREDUCTIONPASS_H
#define MLIR_REDUCER_OPTREDUCTIONPASS_H
#include "PassDetail.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Reducer/ReductionNode.h"
#include "mlir/Reducer/ReductionTreePass.h"
#include "mlir/Reducer/Tester.h"
#include "mlir/Transforms/Passes.h"
#include "llvm/Support/Debug.h"
namespace mlir {
class OptReductionPass : public OptReductionBase<OptReductionPass> {
public:
OptReductionPass() = default;
OptReductionPass(const OptReductionPass &srcPass) = default;
/// Runs the pass instance in the pass pipeline.
void runOnOperation() override;
};
} // end namespace mlir
#endif

View File

@ -9,8 +9,6 @@
#define MLIR_REDUCER_PASSES_H
#include "mlir/Pass/Pass.h"
#include "mlir/Reducer/OptReductionPass.h"
#include "mlir/Reducer/ReductionTreePass.h"
namespace mlir {

View File

@ -24,14 +24,12 @@ def CommonReductionPassOptions {
];
}
def ReductionTree : Pass<"reduction-tree", "ModuleOp"> {
def ReductionTree : Pass<"reduction-tree"> {
let summary = "A general reduction tree pass for the MLIR Reduce Tool";
let constructor = "mlir::createReductionTreePass()";
let options = [
Option<"opReducerName", "op-reducer", "std::string", /* default */"",
"The OpReducer to reduce the module">,
Option<"traversalModeId", "traversal-mode", "unsigned",
/* default */"0", "The graph traversal mode">,
] # CommonReductionPassOptions.options;

View File

@ -1,76 +0,0 @@
//===- OpReducer.h - MLIR Reduce Operation Reducer ------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file defines the OpReducer class. It defines a variant generator method
// with the purpose of producing different variants by eliminating a
// parameterizable type of operations from the parent module.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_REDUCER_PASSES_OPREDUCER_H
#define MLIR_REDUCER_PASSES_OPREDUCER_H
#include <limits>
#include "mlir/Reducer/ReductionNode.h"
#include "mlir/Reducer/Tester.h"
namespace mlir {
class OpReducer {
public:
virtual ~OpReducer() = default;
/// According to rangeToKeep, try to reduce the given module. We implicitly
/// number each interesting operation and rangeToKeep indicates that if an
/// operation's number falls into certain range, then we will not try to
/// reduce that operation.
virtual void reduce(ModuleOp module,
ArrayRef<ReductionNode::Range> rangeToKeep) = 0;
/// Return the number of certain kind of operations that we would like to
/// reduce. This can be used to build a range map to exclude uninterested
/// operations.
virtual int getNumTargetOps(ModuleOp module) const = 0;
};
/// Reducer is a helper class to remove potential uninteresting operations from
/// module.
template <typename OpType>
class Reducer : public OpReducer {
public:
~Reducer() override = default;
int getNumTargetOps(ModuleOp module) const override {
return std::distance(module.getOps<OpType>().begin(),
module.getOps<OpType>().end());
}
void reduce(ModuleOp module,
ArrayRef<ReductionNode::Range> rangeToKeep) override {
std::vector<Operation *> opsToRemove;
size_t keepIndex = 0;
for (auto op : enumerate(module.getOps<OpType>())) {
int index = op.index();
if (keepIndex < rangeToKeep.size() &&
index == rangeToKeep[keepIndex].second)
++keepIndex;
if (keepIndex == rangeToKeep.size() ||
index < rangeToKeep[keepIndex].first)
opsToRemove.push_back(op.value());
}
for (Operation *o : opsToRemove) {
o->dropAllUses();
o->erase();
}
}
};
} // end namespace mlir
#endif

View File

@ -21,19 +21,25 @@
#include <vector>
#include "mlir/Reducer/Tester.h"
#include "mlir/Support/LogicalResult.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/Support/Allocator.h"
#include "llvm/Support/ToolOutputFile.h"
namespace mlir {
class ModuleOp;
class Region;
/// Defines the traversal method options to be used in the reduction tree
/// traversal.
enum TraversalMode { SinglePath, Backtrack, MultiPath };
/// This class defines the ReductionNode which is used to generate variant and
/// keep track of the necessary metadata for the reduction pass. The nodes are
/// linked together in a reduction tree structure which defines the relationship
/// between all the different generated variants.
/// ReductionTreePass will build a reduction tree during module reduction and
/// the ReductionNode represents the vertex of the tree. A ReductionNode records
/// the information such as the reduced module, how this node is reduced from
/// the parent node, etc. This information will be used to construct a reduction
/// path to reduce the certain module.
class ReductionNode {
public:
template <TraversalMode mode>
@ -44,23 +50,46 @@ public:
ReductionNode(ReductionNode *parent, std::vector<Range> range,
llvm::SpecificBumpPtrAllocator<ReductionNode> &allocator);
ReductionNode *getParent() const;
ReductionNode *getParent() const { return parent; }
size_t getSize() const;
/// If the ReductionNode hasn't been tested the interestingness, it'll be the
/// same module as the one in the parent node. Otherwise, the returned module
/// will have been applied certain reduction strategies. Note that it's not
/// necessary to be an interesting case or a reduced module (has smaller size
/// than parent's).
ModuleOp getModule() const { return module; }
/// Return the region we're reducing.
Region &getRegion() const { return *region; }
/// Return the size of the module.
size_t getSize() const { return size; }
/// Returns true if the module exhibits the interesting behavior.
Tester::Interestingness isInteresting() const;
Tester::Interestingness isInteresting() const { return interesting; }
std::vector<Range> getRanges() const;
/// Return the range information that how this node is reduced from the parent
/// node.
ArrayRef<Range> getStartRanges() const { return startRanges; }
std::vector<ReductionNode *> &getVariants();
/// Return the range set we are using to generate variants.
ArrayRef<Range> getRanges() const { return ranges; }
/// Return the generated variants(the child nodes).
ArrayRef<ReductionNode *> getVariants() const { return variants; }
/// Split the ranges and generate new variants.
std::vector<ReductionNode *> generateNewVariants();
ArrayRef<ReductionNode *> generateNewVariants();
/// Update the interestingness result from tester.
void update(std::pair<Tester::Interestingness, size_t> result);
/// Each Reduction Node contains a copy of module for applying rewrite
/// patterns. In addition, we only apply rewrite patterns in a certain region.
/// In init(), we will duplicate the module from parent node and locate the
/// corresponding region.
LogicalResult initialize(ModuleOp parentModule, Region &parentRegion);
private:
/// A custom BFS iterator. The difference between
/// llvm/ADT/BreadthFirstIterator.h is the graph we're exploring is dynamic.
@ -87,8 +116,7 @@ private:
BaseIterator &operator++() {
ReductionNode *top = visitQueue.front();
visitQueue.pop();
std::vector<ReductionNode *> neighbors = getNeighbors(top);
for (ReductionNode *node : neighbors)
for (ReductionNode *node : getNeighbors(top))
visitQueue.push(node);
return *this;
}
@ -103,7 +131,7 @@ private:
ReductionNode *operator->() const { return visitQueue.front(); }
protected:
std::vector<ReductionNode *> getNeighbors(ReductionNode *node) {
ArrayRef<ReductionNode *> getNeighbors(ReductionNode *node) {
return static_cast<T *>(this)->getNeighbors(node);
}
@ -111,21 +139,42 @@ private:
std::queue<ReductionNode *> visitQueue;
};
/// The size of module after applying the range constraints.
/// This is a copy of module from parent node. All the reducer patterns will
/// be applied to this instance.
ModuleOp module;
/// The region of certain operation we're reducing in the module
Region *region;
/// The node we are reduced from. It means we will be in variants of parent
/// node.
ReductionNode *parent;
/// The size of module after applying the reducer patterns with range
/// constraints. This is only valid while the interestingness has been tested.
size_t size;
/// This is true if the module has been evaluated and it exhibits the
/// interesting behavior.
Tester::Interestingness interesting;
ReductionNode *parent;
/// We will only keep the operation with index falls into the ranges.
/// For example, number each function in a certain module and then we will
/// remove the functions with index outside the ranges and see if the
/// resulting module is still interesting.
/// `ranges` represents the selected subset of operations in the region. We
/// implictly number each operation in the region and ReductionTreePass will
/// apply reducer patterns on the operation falls into the `ranges`. We will
/// generate new ReductionNode with subset of `ranges` to see if we can do
/// further reduction. we may split the element in the `ranges` so that we can
/// have more subset variants from `ranges`.
/// Note that after applying the reducer patterns the number of operation in
/// the region may have changed, we need to update the `ranges` after that.
std::vector<Range> ranges;
/// `startRanges` records the ranges of operations selected from the parent
/// node to produce this ReductionNode. It can be used to construct the
/// reduction path from the root. I.e., if we apply the same reducer patterns
/// and `startRanges` selection on the parent region, we will get the same
/// module as this node.
const std::vector<Range> startRanges;
/// This points to the child variants that were created using this node as a
/// starting point.
std::vector<ReductionNode *> variants;
@ -139,9 +188,9 @@ class ReductionNode::iterator<SinglePath>
: public BaseIterator<iterator<SinglePath>> {
friend BaseIterator<iterator<SinglePath>>;
using BaseIterator::BaseIterator;
std::vector<ReductionNode *> getNeighbors(ReductionNode *node);
ArrayRef<ReductionNode *> getNeighbors(ReductionNode *node);
};
} // end namespace mlir
#endif
#endif // MLIR_REDUCER_REDUCTIONNODE_H

View File

@ -0,0 +1,56 @@
//===- ReducePatternInterface.h - Collecting Reduce Patterns ----*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_REDUCER_REDUCTIONPATTERNINTERFACE_H
#define MLIR_REDUCER_REDUCTIONPATTERNINTERFACE_H
#include "mlir/IR/DialectInterface.h"
namespace mlir {
class RewritePatternSet;
/// This is used to report the reduction patterns for a Dialect. While using
/// mlir-reduce to reduce a module, we may want to transform certain cases into
/// simpler forms by applying certain rewrite patterns. Implement the
/// `populateReductionPatterns` to report those patterns by adding them to the
/// RewritePatternSet.
///
/// Example:
/// MyDialectReductionPattern::populateReductionPatterns(
/// RewritePatternSet &patterns) {
/// patterns.add<TensorOpReduction>(patterns.getContext());
/// }
///
/// For DRR, mlir-tblgen will generate a helper function
/// `populateWithGenerated` which has the same signature therefore you can
/// delegate to the helper function as well.
///
/// Example:
/// MyDialectReductionPattern::populateReductionPatterns(
/// RewritePatternSet &patterns) {
/// // Include the autogen file somewhere above.
/// populateWithGenerated(patterns);
/// }
class DialectReductionPatternInterface
: public DialectInterface::Base<DialectReductionPatternInterface> {
public:
/// Patterns provided here are intended to transform operations from a complex
/// form to a simpler form, without breaking the semantics of the program
/// being reduced. For example, you may want to replace the
/// tensor<?xindex> with a known rank and type, e.g. tensor<1xi32>, or
/// replacing an operation with a constant.
virtual void populateReductionPatterns(RewritePatternSet &patterns) const = 0;
protected:
DialectReductionPatternInterface(Dialect *dialect) : Base(dialect) {}
};
} // end namespace mlir
#endif // MLIR_REDUCER_REDUCTIONPATTERNINTERFACE_H

View File

@ -1,50 +0,0 @@
//===- ReductionTreePass.h - Reduction Tree Pass Implementation -*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file defines the Reduction Tree Pass class. It provides a framework for
// the implementation of different reduction passes in the MLIR Reduce tool. It
// allows for custom specification of the variant generation behavior. It
// implements methods that define the different possible traversals of the
// reduction tree.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_REDUCER_REDUCTIONTREEPASS_H
#define MLIR_REDUCER_REDUCTIONTREEPASS_H
#include <vector>
#include "PassDetail.h"
#include "ReductionNode.h"
#include "mlir/Reducer/Passes/OpReducer.h"
#include "mlir/Reducer/Tester.h"
#define DEBUG_TYPE "mlir-reduce"
namespace mlir {
/// This class defines the Reduction Tree Pass. It provides a framework to
/// to implement a reduction pass using a tree structure to keep track of the
/// generated reduced variants.
class ReductionTreePass : public ReductionTreeBase<ReductionTreePass> {
public:
ReductionTreePass() = default;
ReductionTreePass(const ReductionTreePass &pass) = default;
/// Runs the pass instance in the pass pipeline.
void runOnOperation() override;
private:
template <typename IteratorType>
ModuleOp findOptimal(ModuleOp module, std::unique_ptr<OpReducer> reducer,
ReductionNode *node);
};
} // end namespace mlir
#endif

View File

@ -1,7 +1,13 @@
add_mlir_library(MLIRReduce
OptReductionPass.cpp
ReductionNode.cpp
ReductionTreePass.cpp
Tester.cpp
LINK_LIBS PUBLIC
MLIRIR
MLIRPass
MLIRRewrite
MLIRTransformUtils
)
mlir_check_all_link_libraries(MLIRReduce)

View File

@ -12,15 +12,27 @@
//
//===----------------------------------------------------------------------===//
#include "mlir/Reducer/OptReductionPass.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Pass/PassRegistry.h"
#include "mlir/Reducer/PassDetail.h"
#include "mlir/Reducer/Passes.h"
#include "mlir/Reducer/Tester.h"
#include "llvm/Support/Debug.h"
#define DEBUG_TYPE "mlir-reduce"
using namespace mlir;
namespace {
class OptReductionPass : public OptReductionBase<OptReductionPass> {
public:
/// Runs the pass instance in the pass pipeline.
void runOnOperation() override;
};
} // end anonymous namespace
/// Runs the pass instance in the pass pipeline.
void OptReductionPass::runOnOperation() {
LLVM_DEBUG(llvm::dbgs() << "\nOptimization Reduction pass: ");

View File

@ -15,6 +15,7 @@
//===----------------------------------------------------------------------===//
#include "mlir/Reducer/ReductionNode.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "llvm/ADT/STLExtras.h"
#include <algorithm>
@ -23,102 +24,102 @@
using namespace mlir;
ReductionNode::ReductionNode(
ReductionNode *parent, std::vector<Range> ranges,
ReductionNode *parentNode, std::vector<Range> ranges,
llvm::SpecificBumpPtrAllocator<ReductionNode> &allocator)
: size(std::numeric_limits<size_t>::max()),
interesting(Tester::Interestingness::Untested),
/// Root node will have the parent pointer point to themselves.
parent(parent == nullptr ? this : parent), ranges(ranges),
allocator(allocator) {}
/// Returns the size in bytes of the module.
size_t ReductionNode::getSize() const { return size; }
ReductionNode *ReductionNode::getParent() const { return parent; }
/// Returns true if the module exhibits the interesting behavior.
Tester::Interestingness ReductionNode::isInteresting() const {
return interesting;
/// Root node will have the parent pointer point to themselves.
: parent(parentNode == nullptr ? this : parentNode),
size(std::numeric_limits<size_t>::max()),
interesting(Tester::Interestingness::Untested), ranges(ranges),
startRanges(ranges), allocator(allocator) {
if (parent != this)
if (failed(initialize(parent->getModule(), parent->getRegion())))
llvm_unreachable("unexpected initialization failure");
}
std::vector<ReductionNode::Range> ReductionNode::getRanges() const {
return ranges;
LogicalResult ReductionNode::initialize(ModuleOp parentModule,
Region &targetRegion) {
// Use the mapper help us find the corresponding region after module clone.
BlockAndValueMapping mapper;
module = cast<ModuleOp>(parentModule->clone(mapper));
// Use the first block of targetRegion to locate the cloned region.
Block *block = mapper.lookup(&*targetRegion.begin());
region = block->getParent();
return success();
}
std::vector<ReductionNode *> &ReductionNode::getVariants() { return variants; }
#include <iostream>
/// If we haven't explored any variants from this node, we will create N
/// variants, N is the length of `ranges` if N > 1. Otherwise, we will split the
/// max element in `ranges` and create 2 new variants for each call.
std::vector<ReductionNode *> ReductionNode::generateNewVariants() {
std::vector<ReductionNode *> newNodes;
ArrayRef<ReductionNode *> ReductionNode::generateNewVariants() {
int oldNumVariant = getVariants().size();
auto createNewNode = [this](std::vector<Range> ranges) {
return new (allocator.Allocate())
ReductionNode(this, std::move(ranges), allocator);
};
// If we haven't created new variant, then we can create varients by removing
// each of them respectively. For example, given {{1, 3}, {4, 9}}, we can
// produce variants with range {{1, 3}} and {{4, 9}}.
if (variants.size() == 0 && ranges.size() != 1) {
for (const Range &range : ranges) {
std::vector<Range> subRanges = ranges;
if (variants.size() == 0 && getRanges().size() > 1) {
for (const Range &range : getRanges()) {
std::vector<Range> subRanges = getRanges();
llvm::erase_value(subRanges, range);
ReductionNode *newNode = allocator.Allocate();
new (newNode) ReductionNode(this, subRanges, allocator);
newNodes.push_back(newNode);
variants.push_back(newNode);
variants.push_back(createNewNode(std::move(subRanges)));
}
return newNodes;
return getVariants().drop_front(oldNumVariant);
}
// At here, we have created the type of variants mentioned above. We would
// like to split the max range into 2 to create 2 new variants. Continue on
// the above example, we split the range {4, 9} into {4, 6}, {6, 9}, and
// create two variants with range {{1, 3}, {4, 6}} and {{1, 3}, {6, 9}}. The
// result ranges vector will be {{1, 3}, {4, 6}, {6, 9}}.
// final ranges vector will be {{1, 3}, {4, 6}, {6, 9}}.
auto maxElement = std::max_element(
ranges.begin(), ranges.end(), [](const Range &lhs, const Range &rhs) {
return (lhs.second - lhs.first) > (rhs.second - rhs.first);
});
// We can't split range with lenght 1, which means we can't produce new
// The length of range is less than 1, we can't split it to create new
// variant.
if (maxElement->second - maxElement->first == 1)
if (maxElement->second - maxElement->first <= 1)
return {};
auto createNewNode = [this](const std::vector<Range> &ranges) {
ReductionNode *newNode = allocator.Allocate();
new (newNode) ReductionNode(this, ranges, allocator);
return newNode;
};
Range maxRange = *maxElement;
std::vector<Range> subRanges = ranges;
std::vector<Range> subRanges = getRanges();
auto subRangesIter = subRanges.begin() + (maxElement - ranges.begin());
int half = (maxRange.first + maxRange.second) / 2;
*subRangesIter = std::make_pair(maxRange.first, half);
newNodes.push_back(createNewNode(subRanges));
variants.push_back(createNewNode(subRanges));
*subRangesIter = std::make_pair(half, maxRange.second);
newNodes.push_back(createNewNode(subRanges));
variants.push_back(createNewNode(std::move(subRanges)));
variants.insert(variants.end(), newNodes.begin(), newNodes.end());
auto it = ranges.insert(maxElement, std::make_pair(half, maxRange.second));
it = ranges.insert(it, std::make_pair(maxRange.first, half));
// Remove the range that has been split.
ranges.erase(it + 2);
return newNodes;
return getVariants().drop_front(oldNumVariant);
}
void ReductionNode::update(std::pair<Tester::Interestingness, size_t> result) {
std::tie(interesting, size) = result;
// After applying reduction, the number of operation in the region may have
// changed. Non-interesting case won't be explored thus it's safe to keep it
// in a stale status.
if (interesting == Tester::Interestingness::True) {
// This module may has been updated. Reset the range.
ranges.clear();
ranges.push_back({0, std::distance(region->op_begin(), region->op_end())});
}
}
std::vector<ReductionNode *>
ArrayRef<ReductionNode *>
ReductionNode::iterator<SinglePath>::getNeighbors(ReductionNode *node) {
// Single Path: Traverses the smallest successful variant at each level until
// no new successful variants can be created at that level.
llvm::ArrayRef<ReductionNode *> variantsFromParent =
ArrayRef<ReductionNode *> variantsFromParent =
node->getParent()->getVariants();
// The parent node created several variants and they may be waiting for
@ -139,7 +140,8 @@ ReductionNode::iterator<SinglePath>::getNeighbors(ReductionNode *node) {
smallest = node;
}
if (smallest != nullptr) {
if (smallest != nullptr &&
smallest->getSize() < node->getParent()->getSize()) {
// We got a smallest one, keep traversing from this node.
node = smallest;
} else {

View File

@ -0,0 +1,247 @@
//===- ReductionTreePass.cpp - ReductionTreePass Implementation -----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file defines the Reduction Tree Pass class. It provides a framework for
// the implementation of different reduction passes in the MLIR Reduce tool. It
// allows for custom specification of the variant generation behavior. It
// implements methods that define the different possible traversals of the
// reduction tree.
//
//===----------------------------------------------------------------------===//
#include "mlir/IR/DialectInterface.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/Reducer/PassDetail.h"
#include "mlir/Reducer/Passes.h"
#include "mlir/Reducer/ReductionNode.h"
#include "mlir/Reducer/ReductionPatternInterface.h"
#include "mlir/Reducer/Tester.h"
#include "mlir/Rewrite/FrozenRewritePatternSet.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Allocator.h"
#include "llvm/Support/ManagedStatic.h"
using namespace mlir;
/// We implicitly number each operation in the region and if an operation's
/// number falls into rangeToKeep, we need to keep it and apply the given
/// rewrite patterns on it.
static void applyPatterns(Region &region,
const FrozenRewritePatternSet &patterns,
ArrayRef<ReductionNode::Range> rangeToKeep,
bool eraseOpNotInRange) {
std::vector<Operation *> opsNotInRange;
std::vector<Operation *> opsInRange;
size_t keepIndex = 0;
for (auto op : enumerate(region.getOps())) {
int index = op.index();
if (keepIndex < rangeToKeep.size() &&
index == rangeToKeep[keepIndex].second)
++keepIndex;
if (keepIndex == rangeToKeep.size() || index < rangeToKeep[keepIndex].first)
opsNotInRange.push_back(&op.value());
else
opsInRange.push_back(&op.value());
}
// `applyOpPatternsAndFold` may erase the ops so we can't do the pattern
// matching in above iteration. Besides, erase op not-in-range may end up in
// invalid module, so `applyOpPatternsAndFold` should come before that
// transform.
for (Operation *op : opsInRange)
// `applyOpPatternsAndFold` returns whether the op is convered. Omit it
// because we don't have expectation this reduction will be success or not.
(void)applyOpPatternsAndFold(op, patterns);
if (eraseOpNotInRange)
for (Operation *op : opsNotInRange) {
op->dropAllUses();
op->erase();
}
}
/// We will apply the reducer patterns to the operations in the ranges specified
/// by ReductionNode. Note that we are not able to remove an operation without
/// replacing it with another valid operation. However, The validity of module
/// reduction is based on the Tester provided by the user and that means certain
/// invalid module is still interested by the use. Thus we provide an
/// alternative way to remove operations, which is using `eraseOpNotInRange` to
/// erase the operations not in the range specified by ReductionNode.
template <typename IteratorType>
static void findOptimal(ModuleOp module, Region &region,
const FrozenRewritePatternSet &patterns,
const Tester &test, bool eraseOpNotInRange) {
std::pair<Tester::Interestingness, size_t> initStatus =
test.isInteresting(module);
// While exploring the reduction tree, we always branch from an interesting
// node. Thus the root node must be interesting.
if (initStatus.first != Tester::Interestingness::True)
return;
llvm::SpecificBumpPtrAllocator<ReductionNode> allocator;
std::vector<ReductionNode::Range> ranges{
{0, std::distance(region.op_begin(), region.op_end())}};
ReductionNode *root = allocator.Allocate();
new (root) ReductionNode(nullptr, std::move(ranges), allocator);
// Duplicate the module for root node and locate the region in the copy.
if (failed(root->initialize(module, region)))
llvm_unreachable("unexpected initialization failure");
root->update(initStatus);
ReductionNode *smallestNode = root;
IteratorType iter(root);
while (iter != IteratorType::end()) {
ReductionNode &currentNode = *iter;
Region &curRegion = currentNode.getRegion();
applyPatterns(curRegion, patterns, currentNode.getRanges(),
eraseOpNotInRange);
currentNode.update(test.isInteresting(currentNode.getModule()));
if (currentNode.isInteresting() == Tester::Interestingness::True &&
currentNode.getSize() < smallestNode->getSize())
smallestNode = &currentNode;
++iter;
}
// At here, we have found an optimal path to reduce the given region. Retrieve
// the path and apply the reducer to it.
SmallVector<ReductionNode *> trace;
ReductionNode *curNode = smallestNode;
trace.push_back(curNode);
while (curNode != root) {
curNode = curNode->getParent();
trace.push_back(curNode);
}
// Reduce the region through the optimal path.
while (!trace.empty()) {
ReductionNode *top = trace.pop_back_val();
applyPatterns(region, patterns, top->getStartRanges(), eraseOpNotInRange);
}
if (test.isInteresting(module).first != Tester::Interestingness::True)
llvm::report_fatal_error("Reduced module is not interesting");
if (test.isInteresting(module).second != smallestNode->getSize())
llvm::report_fatal_error(
"Reduced module doesn't have consistent size with smallestNode");
}
template <typename IteratorType>
static void findOptimal(ModuleOp module, Region &region,
const FrozenRewritePatternSet &patterns,
const Tester &test) {
// We separate the reduction process into 2 steps, the first one is to erase
// redundant operations and the second one is to apply the reducer patterns.
// In the first phase, we don't apply any patterns so that we only select the
// range of operations to keep to the module stay interesting.
findOptimal<IteratorType>(module, region, /*patterns=*/{}, test,
/*eraseOpNotInRange=*/true);
// In the second phase, we suppose that no operation is redundant, so we try
// to rewrite the operation into simpler form.
findOptimal<IteratorType>(module, region, patterns, test,
/*eraseOpNotInRange=*/false);
}
namespace {
//===----------------------------------------------------------------------===//
// Reduction Pattern Interface Collection
//===----------------------------------------------------------------------===//
class ReductionPatternInterfaceCollection
: public DialectInterfaceCollection<DialectReductionPatternInterface> {
public:
using Base::Base;
// Collect the reduce patterns defined by each dialect.
void populateReductionPatterns(RewritePatternSet &pattern) const {
for (const DialectReductionPatternInterface &interface : *this)
interface.populateReductionPatterns(pattern);
}
};
//===----------------------------------------------------------------------===//
// ReductionTreePass
//===----------------------------------------------------------------------===//
/// This class defines the Reduction Tree Pass. It provides a framework to
/// to implement a reduction pass using a tree structure to keep track of the
/// generated reduced variants.
class ReductionTreePass : public ReductionTreeBase<ReductionTreePass> {
public:
ReductionTreePass() = default;
ReductionTreePass(const ReductionTreePass &pass) = default;
LogicalResult initialize(MLIRContext *context) override;
/// Runs the pass instance in the pass pipeline.
void runOnOperation() override;
private:
void reduceOp(ModuleOp module, Region &region);
FrozenRewritePatternSet reducerPatterns;
};
} // end anonymous namespace
LogicalResult ReductionTreePass::initialize(MLIRContext *context) {
RewritePatternSet patterns(context);
ReductionPatternInterfaceCollection reducePatternCollection(context);
reducePatternCollection.populateReductionPatterns(patterns);
reducerPatterns = std::move(patterns);
return success();
}
void ReductionTreePass::runOnOperation() {
Operation *topOperation = getOperation();
while (topOperation->getParentOp() != nullptr)
topOperation = topOperation->getParentOp();
ModuleOp module = cast<ModuleOp>(topOperation);
SmallVector<Operation *, 8> workList;
workList.push_back(getOperation());
do {
Operation *op = workList.pop_back_val();
for (Region &region : op->getRegions())
if (!region.empty())
reduceOp(module, region);
for (Region &region : op->getRegions())
for (Operation &op : region.getOps())
if (op.getNumRegions() != 0)
workList.push_back(&op);
} while (!workList.empty());
}
void ReductionTreePass::reduceOp(ModuleOp module, Region &region) {
Tester test(testerName, testerArgs);
switch (traversalModeId) {
case TraversalMode::SinglePath:
findOptimal<ReductionNode::iterator<TraversalMode::SinglePath>>(
module, region, reducerPatterns, test);
break;
default:
llvm_unreachable("Unsupported mode");
}
}
std::unique_ptr<Pass> mlir::createReductionTreePass() {
return std::make_unique<ReductionTreePass>();
}

View File

@ -15,7 +15,7 @@
//===----------------------------------------------------------------------===//
#include "mlir/Reducer/Tester.h"
#include "mlir/IR/Verifier.h"
#include "llvm/Support/ToolOutputFile.h"
using namespace mlir;
@ -25,6 +25,12 @@ Tester::Tester(StringRef scriptName, ArrayRef<std::string> scriptArgs)
std::pair<Tester::Interestingness, size_t>
Tester::isInteresting(ModuleOp module) const {
// The reduced module should always be vaild, or we may end up retaining the
// error message by an invalid case. Besides, an invalid module may not be
// able to print properly.
if (failed(verify(module)))
return std::make_pair(Interestingness::False, /*size=*/0);
SmallString<128> filepath;
int fd;
@ -50,7 +56,6 @@ Tester::isInteresting(ModuleOp module) const {
/// true if the interesting behavior is present in the test case or false
/// otherwise.
Tester::Interestingness Tester::isInteresting(StringRef testCase) const {
std::vector<StringRef> testerArgs;
testerArgs.push_back(testCase);

View File

@ -60,6 +60,7 @@ add_mlir_library(MLIRTestDialect
MLIRInferTypeOpInterface
MLIRLinalgTransforms
MLIRPass
MLIRReduce
MLIRStandard
MLIRStandardOpsTransforms
MLIRTransformUtils

View File

@ -8,6 +8,7 @@
#include "TestDialect.h"
#include "TestAttributes.h"
#include "TestInterfaces.h"
#include "TestTypes.h"
#include "mlir/Dialect/DLTI/DLTI.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
@ -16,6 +17,7 @@
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Reducer/ReductionPatternInterface.h"
#include "mlir/Transforms/FoldUtils.h"
#include "mlir/Transforms/InliningUtils.h"
#include "llvm/ADT/StringSwitch.h"
@ -170,6 +172,18 @@ struct TestInlinerInterface : public DialectInlinerInterface {
return builder.create<TestCastOp>(conversionLoc, resultType, input);
}
};
struct TestReductionPatternInterface : public DialectReductionPatternInterface {
public:
TestReductionPatternInterface(Dialect *dialect)
: DialectReductionPatternInterface(dialect) {}
virtual void
populateReductionPatterns(RewritePatternSet &patterns) const final {
populateTestReductionPatterns(patterns);
}
};
} // end anonymous namespace
//===----------------------------------------------------------------------===//
@ -207,7 +221,7 @@ void TestDialect::initialize() {
#include "TestOps.cpp.inc"
>();
addInterfaces<TestOpAsmInterface, TestDialectFoldInterface,
TestInlinerInterface>();
TestInlinerInterface, TestReductionPatternInterface>();
allowUnknownOperations();
// Instantiate our fallback op interface that we'll use on specific

View File

@ -34,6 +34,7 @@
namespace mlir {
class DLTIDialect;
class RewritePatternSet;
} // namespace mlir
#include "TestOpEnums.h.inc"
@ -47,6 +48,7 @@ class DLTIDialect;
namespace mlir {
namespace test {
void registerTestDialect(DialectRegistry &registry);
void populateTestReductionPatterns(RewritePatternSet &patterns);
} // namespace test
} // namespace mlir

View File

@ -2113,4 +2113,19 @@ def DataLayoutQueryOp : TEST_Op<"data_layout_query"> {
let results = (outs AnyType:$res);
}
//===----------------------------------------------------------------------===//
// Test Reducer Patterns
//===----------------------------------------------------------------------===//
def OpCrashLong : TEST_Op<"op_crash_long"> {
let arguments = (ins I32, I32, I32);
let results = (outs I32);
}
def OpCrashShort : TEST_Op<"op_crash_short"> {
let results = (outs I32);
}
def : Pat<(OpCrashLong $_, $_, $_), (OpCrashShort)>;
#endif // TEST_OPS

View File

@ -58,6 +58,14 @@ namespace {
#include "TestPatterns.inc"
} // end anonymous namespace
//===----------------------------------------------------------------------===//
// Test Reduce Pattern Interface
//===----------------------------------------------------------------------===//
void mlir::test::populateTestReductionPatterns(RewritePatternSet &patterns) {
populateWithGenerated(patterns);
}
//===----------------------------------------------------------------------===//
// Canonicalizer Driver.
//===----------------------------------------------------------------------===//

View File

@ -38,7 +38,7 @@ void TestReducer::runOnFunction() {
op.walk([&](Operation *op) {
StringRef opName = op->getName().getStringRef();
if (opName == "test.crashOp") {
if (opName.contains("op_crash")) {
llvm::errs() << "MLIR Reducer Test generated failure: Found "
"\"crashOp\" operation\n";
exit(1);

View File

@ -0,0 +1,20 @@
// UNSUPPORTED: system-windows
// RUN: mlir-reduce %s -reduction-tree='traversal-mode=0 test=%S/failure-test.sh' | FileCheck %s
// "test.op_crash_long" should be replaced with a shorter form "test.op_crash_short".
// CHECK-NOT: func @simple1() {
func @simple1() {
return
}
// CHECK-LABEL: func @simple2(%arg0: i32, %arg1: i32, %arg2: i32) {
func @simple2(%arg0: i32, %arg1: i32, %arg2: i32) {
// CHECK-LABEL: %0 = "test.op_crash_short"() : () -> i32
%0 = "test.op_crash_long" (%arg0, %arg1, %arg2) : (i32, i32, i32) -> i32
return
}
// CHECK-NOT: func @simple5() {
func @simple5() {
return
}

View File

@ -12,6 +12,6 @@ func nested @dead_nested_function()
// CHECK-LABEL: func @simple1(%arg0: i1, %arg1: memref<2xf32>, %arg2: memref<2xf32>) {
func @simple1(%arg0: i1, %arg1: memref<2xf32>, %arg2: memref<2xf32>) {
"test.crashOp" () : () -> ()
"test.op_crash" () : () -> ()
return
}

View File

@ -1,5 +1,5 @@
// UNSUPPORTED: system-windows
// RUN: mlir-reduce %s -reduction-tree='op-reducer=func traversal-mode=0 test=%S/failure-test.sh' | FileCheck %s
// RUN: mlir-reduce %s -reduction-tree='traversal-mode=0 test=%S/failure-test.sh' | FileCheck %s
// This input should be reduced by the pass pipeline so that only
// the @simple5 function remains as this is the shortest function
// containing the interesting behavior.
@ -16,7 +16,7 @@ func @simple2() {
// CHECK-LABEL: func @simple3() {
func @simple3() {
"test.crashOp" () : () -> ()
"test.op_crash" () : () -> ()
return
}
@ -29,7 +29,7 @@ func @simple4(%arg0: i1, %arg1: memref<2xf32>, %arg2: memref<2xf32>) {
%0 = memref.alloc() : memref<2xf32>
br ^bb3(%0 : memref<2xf32>)
^bb3(%1: memref<2xf32>):
"test.crashOp"(%1, %arg2) : (memref<2xf32>, memref<2xf32>) -> ()
"test.op_crash"(%1, %arg2) : (memref<2xf32>, memref<2xf32>) -> ()
return
}

View File

@ -1,5 +1,5 @@
// UNSUPPORTED: system-windows
// RUN: mlir-reduce %s -reduction-tree='op-reducer=func traversal-mode=0 test=%S/test.sh'
// RUN: mlir-reduce %s -reduction-tree='traversal-mode=0 test=%S/test.sh'
func @simple1(%arg0: i1, %arg1: memref<2xf32>, %arg2: memref<2xf32>) {
cond_br %arg0, ^bb1, ^bb2

View File

@ -2,6 +2,6 @@
// RUN: not mlir-opt %s -test-mlir-reducer -pass-test function-reducer
func @test() {
"test.crashOp"() : () -> ()
"test.op_crash"() : () -> ()
return
}

View File

@ -43,9 +43,6 @@ set(LIBS
)
add_llvm_tool(mlir-reduce
OptReductionPass.cpp
ReductionNode.cpp
ReductionTreePass.cpp
mlir-reduce.cpp
ADDITIONAL_HEADER_DIRS

View File

@ -1,107 +0,0 @@
//===- ReductionTreePass.cpp - ReductionTreePass Implementation -----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file defines the Reduction Tree Pass class. It provides a framework for
// the implementation of different reduction passes in the MLIR Reduce tool. It
// allows for custom specification of the variant generation behavior. It
// implements methods that define the different possible traversals of the
// reduction tree.
//
//===----------------------------------------------------------------------===//
#include "mlir/Reducer/ReductionTreePass.h"
#include "mlir/Reducer/Passes.h"
#include "llvm/Support/Allocator.h"
using namespace mlir;
static std::unique_ptr<OpReducer> getOpReducer(llvm::StringRef opType) {
if (opType == ModuleOp::getOperationName())
return std::make_unique<Reducer<ModuleOp>>();
else if (opType == FuncOp::getOperationName())
return std::make_unique<Reducer<FuncOp>>();
llvm_unreachable("Now only supports two built-in ops");
}
void ReductionTreePass::runOnOperation() {
ModuleOp module = this->getOperation();
std::unique_ptr<OpReducer> reducer = getOpReducer(opReducerName);
std::vector<std::pair<int, int>> ranges = {
{0, reducer->getNumTargetOps(module)}};
llvm::SpecificBumpPtrAllocator<ReductionNode> allocator;
ReductionNode *root = allocator.Allocate();
new (root) ReductionNode(nullptr, ranges, allocator);
ModuleOp golden = module;
switch (traversalModeId) {
case TraversalMode::SinglePath:
golden = findOptimal<ReductionNode::iterator<TraversalMode::SinglePath>>(
module, std::move(reducer), root);
break;
default:
llvm_unreachable("Unsupported mode");
}
if (golden != module) {
module.getBody()->clear();
module.getBody()->getOperations().splice(module.getBody()->begin(),
golden.getBody()->getOperations());
golden->destroy();
}
}
template <typename IteratorType>
ModuleOp ReductionTreePass::findOptimal(ModuleOp module,
std::unique_ptr<OpReducer> reducer,
ReductionNode *root) {
Tester test(testerName, testerArgs);
std::pair<Tester::Interestingness, size_t> initStatus =
test.isInteresting(module);
if (initStatus.first != Tester::Interestingness::True) {
LLVM_DEBUG(llvm::dbgs() << "\nThe original input is not interested");
return module;
}
root->update(initStatus);
ReductionNode *smallestNode = root;
ModuleOp golden = module;
IteratorType iter(root);
while (iter != IteratorType::end()) {
ModuleOp cloneModule = module.clone();
ReductionNode &currentNode = *iter;
reducer->reduce(cloneModule, currentNode.getRanges());
std::pair<Tester::Interestingness, size_t> result =
test.isInteresting(cloneModule);
currentNode.update(result);
if (result.first == Tester::Interestingness::True &&
result.second < smallestNode->getSize()) {
smallestNode = &currentNode;
golden = cloneModule;
} else {
cloneModule->destroy();
}
++iter;
}
return golden;
}
std::unique_ptr<Pass> mlir::createReductionTreePass() {
return std::make_unique<ReductionTreePass>();
}

View File

@ -13,22 +13,14 @@
//
//===----------------------------------------------------------------------===//
#include <vector>
#include "mlir/InitAllDialects.h"
#include "mlir/InitAllPasses.h"
#include "mlir/Parser.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Reducer/OptReductionPass.h"
#include "mlir/Reducer/Passes.h"
#include "mlir/Reducer/Passes/OpReducer.h"
#include "mlir/Reducer/ReductionNode.h"
#include "mlir/Reducer/ReductionTreePass.h"
#include "mlir/Reducer/Tester.h"
#include "mlir/Support/FileUtilities.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/Passes.h"
#include "llvm/Support/InitLLVM.h"
#include "llvm/Support/ToolOutputFile.h"