forked from OSchip/llvm-project
[mlir-reduce] Reducer refactor.
* A Reducer is a kind of RewritePattern, so it's just the same as writing graph rewrite. * ReductionTreePass operates on Operation rather than ModuleOp, so that * we are able to reduce a nested structure(e.g., module in module) by * self-nesting. Reviewed By: jpienaar, rriddle Differential Revision: https://reviews.llvm.org/D101046
This commit is contained in:
parent
26044c6a54
commit
c484c7dd9d
|
@ -1,41 +0,0 @@
|
|||
//===- OptReductionPass.h - Optimization Reduction Pass Wrapper -*- C++ -*-===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// This file defines the Opt Reduction Pass Wrapper. It creates a MLIR pass to
|
||||
// run any optimization pass within it and only replaces the output module with
|
||||
// the transformed version if it is smaller and interesting.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef MLIR_REDUCER_OPTREDUCTIONPASS_H
|
||||
#define MLIR_REDUCER_OPTREDUCTIONPASS_H
|
||||
|
||||
#include "PassDetail.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "mlir/Pass/PassManager.h"
|
||||
#include "mlir/Reducer/ReductionNode.h"
|
||||
#include "mlir/Reducer/ReductionTreePass.h"
|
||||
#include "mlir/Reducer/Tester.h"
|
||||
#include "mlir/Transforms/Passes.h"
|
||||
#include "llvm/Support/Debug.h"
|
||||
|
||||
namespace mlir {
|
||||
|
||||
class OptReductionPass : public OptReductionBase<OptReductionPass> {
|
||||
public:
|
||||
OptReductionPass() = default;
|
||||
|
||||
OptReductionPass(const OptReductionPass &srcPass) = default;
|
||||
|
||||
/// Runs the pass instance in the pass pipeline.
|
||||
void runOnOperation() override;
|
||||
};
|
||||
|
||||
} // end namespace mlir
|
||||
|
||||
#endif
|
|
@ -9,8 +9,6 @@
|
|||
#define MLIR_REDUCER_PASSES_H
|
||||
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "mlir/Reducer/OptReductionPass.h"
|
||||
#include "mlir/Reducer/ReductionTreePass.h"
|
||||
|
||||
namespace mlir {
|
||||
|
||||
|
|
|
@ -24,14 +24,12 @@ def CommonReductionPassOptions {
|
|||
];
|
||||
}
|
||||
|
||||
def ReductionTree : Pass<"reduction-tree", "ModuleOp"> {
|
||||
def ReductionTree : Pass<"reduction-tree"> {
|
||||
let summary = "A general reduction tree pass for the MLIR Reduce Tool";
|
||||
|
||||
let constructor = "mlir::createReductionTreePass()";
|
||||
|
||||
let options = [
|
||||
Option<"opReducerName", "op-reducer", "std::string", /* default */"",
|
||||
"The OpReducer to reduce the module">,
|
||||
Option<"traversalModeId", "traversal-mode", "unsigned",
|
||||
/* default */"0", "The graph traversal mode">,
|
||||
] # CommonReductionPassOptions.options;
|
||||
|
|
|
@ -1,76 +0,0 @@
|
|||
//===- OpReducer.h - MLIR Reduce Operation Reducer ------------*- C++ -*-===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// This file defines the OpReducer class. It defines a variant generator method
|
||||
// with the purpose of producing different variants by eliminating a
|
||||
// parameterizable type of operations from the parent module.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef MLIR_REDUCER_PASSES_OPREDUCER_H
|
||||
#define MLIR_REDUCER_PASSES_OPREDUCER_H
|
||||
|
||||
#include <limits>
|
||||
|
||||
#include "mlir/Reducer/ReductionNode.h"
|
||||
#include "mlir/Reducer/Tester.h"
|
||||
|
||||
namespace mlir {
|
||||
|
||||
class OpReducer {
|
||||
public:
|
||||
virtual ~OpReducer() = default;
|
||||
/// According to rangeToKeep, try to reduce the given module. We implicitly
|
||||
/// number each interesting operation and rangeToKeep indicates that if an
|
||||
/// operation's number falls into certain range, then we will not try to
|
||||
/// reduce that operation.
|
||||
virtual void reduce(ModuleOp module,
|
||||
ArrayRef<ReductionNode::Range> rangeToKeep) = 0;
|
||||
/// Return the number of certain kind of operations that we would like to
|
||||
/// reduce. This can be used to build a range map to exclude uninterested
|
||||
/// operations.
|
||||
virtual int getNumTargetOps(ModuleOp module) const = 0;
|
||||
};
|
||||
|
||||
/// Reducer is a helper class to remove potential uninteresting operations from
|
||||
/// module.
|
||||
template <typename OpType>
|
||||
class Reducer : public OpReducer {
|
||||
public:
|
||||
~Reducer() override = default;
|
||||
|
||||
int getNumTargetOps(ModuleOp module) const override {
|
||||
return std::distance(module.getOps<OpType>().begin(),
|
||||
module.getOps<OpType>().end());
|
||||
}
|
||||
|
||||
void reduce(ModuleOp module,
|
||||
ArrayRef<ReductionNode::Range> rangeToKeep) override {
|
||||
std::vector<Operation *> opsToRemove;
|
||||
size_t keepIndex = 0;
|
||||
|
||||
for (auto op : enumerate(module.getOps<OpType>())) {
|
||||
int index = op.index();
|
||||
if (keepIndex < rangeToKeep.size() &&
|
||||
index == rangeToKeep[keepIndex].second)
|
||||
++keepIndex;
|
||||
if (keepIndex == rangeToKeep.size() ||
|
||||
index < rangeToKeep[keepIndex].first)
|
||||
opsToRemove.push_back(op.value());
|
||||
}
|
||||
|
||||
for (Operation *o : opsToRemove) {
|
||||
o->dropAllUses();
|
||||
o->erase();
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // end namespace mlir
|
||||
|
||||
#endif
|
|
@ -21,19 +21,25 @@
|
|||
#include <vector>
|
||||
|
||||
#include "mlir/Reducer/Tester.h"
|
||||
#include "mlir/Support/LogicalResult.h"
|
||||
#include "llvm/ADT/ArrayRef.h"
|
||||
#include "llvm/Support/Allocator.h"
|
||||
#include "llvm/Support/ToolOutputFile.h"
|
||||
|
||||
namespace mlir {
|
||||
|
||||
class ModuleOp;
|
||||
class Region;
|
||||
|
||||
/// Defines the traversal method options to be used in the reduction tree
|
||||
/// traversal.
|
||||
enum TraversalMode { SinglePath, Backtrack, MultiPath };
|
||||
|
||||
/// This class defines the ReductionNode which is used to generate variant and
|
||||
/// keep track of the necessary metadata for the reduction pass. The nodes are
|
||||
/// linked together in a reduction tree structure which defines the relationship
|
||||
/// between all the different generated variants.
|
||||
/// ReductionTreePass will build a reduction tree during module reduction and
|
||||
/// the ReductionNode represents the vertex of the tree. A ReductionNode records
|
||||
/// the information such as the reduced module, how this node is reduced from
|
||||
/// the parent node, etc. This information will be used to construct a reduction
|
||||
/// path to reduce the certain module.
|
||||
class ReductionNode {
|
||||
public:
|
||||
template <TraversalMode mode>
|
||||
|
@ -44,23 +50,46 @@ public:
|
|||
ReductionNode(ReductionNode *parent, std::vector<Range> range,
|
||||
llvm::SpecificBumpPtrAllocator<ReductionNode> &allocator);
|
||||
|
||||
ReductionNode *getParent() const;
|
||||
ReductionNode *getParent() const { return parent; }
|
||||
|
||||
size_t getSize() const;
|
||||
/// If the ReductionNode hasn't been tested the interestingness, it'll be the
|
||||
/// same module as the one in the parent node. Otherwise, the returned module
|
||||
/// will have been applied certain reduction strategies. Note that it's not
|
||||
/// necessary to be an interesting case or a reduced module (has smaller size
|
||||
/// than parent's).
|
||||
ModuleOp getModule() const { return module; }
|
||||
|
||||
/// Return the region we're reducing.
|
||||
Region &getRegion() const { return *region; }
|
||||
|
||||
/// Return the size of the module.
|
||||
size_t getSize() const { return size; }
|
||||
|
||||
/// Returns true if the module exhibits the interesting behavior.
|
||||
Tester::Interestingness isInteresting() const;
|
||||
Tester::Interestingness isInteresting() const { return interesting; }
|
||||
|
||||
std::vector<Range> getRanges() const;
|
||||
/// Return the range information that how this node is reduced from the parent
|
||||
/// node.
|
||||
ArrayRef<Range> getStartRanges() const { return startRanges; }
|
||||
|
||||
std::vector<ReductionNode *> &getVariants();
|
||||
/// Return the range set we are using to generate variants.
|
||||
ArrayRef<Range> getRanges() const { return ranges; }
|
||||
|
||||
/// Return the generated variants(the child nodes).
|
||||
ArrayRef<ReductionNode *> getVariants() const { return variants; }
|
||||
|
||||
/// Split the ranges and generate new variants.
|
||||
std::vector<ReductionNode *> generateNewVariants();
|
||||
ArrayRef<ReductionNode *> generateNewVariants();
|
||||
|
||||
/// Update the interestingness result from tester.
|
||||
void update(std::pair<Tester::Interestingness, size_t> result);
|
||||
|
||||
/// Each Reduction Node contains a copy of module for applying rewrite
|
||||
/// patterns. In addition, we only apply rewrite patterns in a certain region.
|
||||
/// In init(), we will duplicate the module from parent node and locate the
|
||||
/// corresponding region.
|
||||
LogicalResult initialize(ModuleOp parentModule, Region &parentRegion);
|
||||
|
||||
private:
|
||||
/// A custom BFS iterator. The difference between
|
||||
/// llvm/ADT/BreadthFirstIterator.h is the graph we're exploring is dynamic.
|
||||
|
@ -87,8 +116,7 @@ private:
|
|||
BaseIterator &operator++() {
|
||||
ReductionNode *top = visitQueue.front();
|
||||
visitQueue.pop();
|
||||
std::vector<ReductionNode *> neighbors = getNeighbors(top);
|
||||
for (ReductionNode *node : neighbors)
|
||||
for (ReductionNode *node : getNeighbors(top))
|
||||
visitQueue.push(node);
|
||||
return *this;
|
||||
}
|
||||
|
@ -103,7 +131,7 @@ private:
|
|||
ReductionNode *operator->() const { return visitQueue.front(); }
|
||||
|
||||
protected:
|
||||
std::vector<ReductionNode *> getNeighbors(ReductionNode *node) {
|
||||
ArrayRef<ReductionNode *> getNeighbors(ReductionNode *node) {
|
||||
return static_cast<T *>(this)->getNeighbors(node);
|
||||
}
|
||||
|
||||
|
@ -111,21 +139,42 @@ private:
|
|||
std::queue<ReductionNode *> visitQueue;
|
||||
};
|
||||
|
||||
/// The size of module after applying the range constraints.
|
||||
/// This is a copy of module from parent node. All the reducer patterns will
|
||||
/// be applied to this instance.
|
||||
ModuleOp module;
|
||||
|
||||
/// The region of certain operation we're reducing in the module
|
||||
Region *region;
|
||||
|
||||
/// The node we are reduced from. It means we will be in variants of parent
|
||||
/// node.
|
||||
ReductionNode *parent;
|
||||
|
||||
/// The size of module after applying the reducer patterns with range
|
||||
/// constraints. This is only valid while the interestingness has been tested.
|
||||
size_t size;
|
||||
|
||||
/// This is true if the module has been evaluated and it exhibits the
|
||||
/// interesting behavior.
|
||||
Tester::Interestingness interesting;
|
||||
|
||||
ReductionNode *parent;
|
||||
|
||||
/// We will only keep the operation with index falls into the ranges.
|
||||
/// For example, number each function in a certain module and then we will
|
||||
/// remove the functions with index outside the ranges and see if the
|
||||
/// resulting module is still interesting.
|
||||
/// `ranges` represents the selected subset of operations in the region. We
|
||||
/// implictly number each operation in the region and ReductionTreePass will
|
||||
/// apply reducer patterns on the operation falls into the `ranges`. We will
|
||||
/// generate new ReductionNode with subset of `ranges` to see if we can do
|
||||
/// further reduction. we may split the element in the `ranges` so that we can
|
||||
/// have more subset variants from `ranges`.
|
||||
/// Note that after applying the reducer patterns the number of operation in
|
||||
/// the region may have changed, we need to update the `ranges` after that.
|
||||
std::vector<Range> ranges;
|
||||
|
||||
/// `startRanges` records the ranges of operations selected from the parent
|
||||
/// node to produce this ReductionNode. It can be used to construct the
|
||||
/// reduction path from the root. I.e., if we apply the same reducer patterns
|
||||
/// and `startRanges` selection on the parent region, we will get the same
|
||||
/// module as this node.
|
||||
const std::vector<Range> startRanges;
|
||||
|
||||
/// This points to the child variants that were created using this node as a
|
||||
/// starting point.
|
||||
std::vector<ReductionNode *> variants;
|
||||
|
@ -139,9 +188,9 @@ class ReductionNode::iterator<SinglePath>
|
|||
: public BaseIterator<iterator<SinglePath>> {
|
||||
friend BaseIterator<iterator<SinglePath>>;
|
||||
using BaseIterator::BaseIterator;
|
||||
std::vector<ReductionNode *> getNeighbors(ReductionNode *node);
|
||||
ArrayRef<ReductionNode *> getNeighbors(ReductionNode *node);
|
||||
};
|
||||
|
||||
} // end namespace mlir
|
||||
|
||||
#endif
|
||||
#endif // MLIR_REDUCER_REDUCTIONNODE_H
|
||||
|
|
|
@ -0,0 +1,56 @@
|
|||
//===- ReducePatternInterface.h - Collecting Reduce Patterns ----*- C++ -*-===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef MLIR_REDUCER_REDUCTIONPATTERNINTERFACE_H
|
||||
#define MLIR_REDUCER_REDUCTIONPATTERNINTERFACE_H
|
||||
|
||||
#include "mlir/IR/DialectInterface.h"
|
||||
|
||||
namespace mlir {
|
||||
|
||||
class RewritePatternSet;
|
||||
|
||||
/// This is used to report the reduction patterns for a Dialect. While using
|
||||
/// mlir-reduce to reduce a module, we may want to transform certain cases into
|
||||
/// simpler forms by applying certain rewrite patterns. Implement the
|
||||
/// `populateReductionPatterns` to report those patterns by adding them to the
|
||||
/// RewritePatternSet.
|
||||
///
|
||||
/// Example:
|
||||
/// MyDialectReductionPattern::populateReductionPatterns(
|
||||
/// RewritePatternSet &patterns) {
|
||||
/// patterns.add<TensorOpReduction>(patterns.getContext());
|
||||
/// }
|
||||
///
|
||||
/// For DRR, mlir-tblgen will generate a helper function
|
||||
/// `populateWithGenerated` which has the same signature therefore you can
|
||||
/// delegate to the helper function as well.
|
||||
///
|
||||
/// Example:
|
||||
/// MyDialectReductionPattern::populateReductionPatterns(
|
||||
/// RewritePatternSet &patterns) {
|
||||
/// // Include the autogen file somewhere above.
|
||||
/// populateWithGenerated(patterns);
|
||||
/// }
|
||||
class DialectReductionPatternInterface
|
||||
: public DialectInterface::Base<DialectReductionPatternInterface> {
|
||||
public:
|
||||
/// Patterns provided here are intended to transform operations from a complex
|
||||
/// form to a simpler form, without breaking the semantics of the program
|
||||
/// being reduced. For example, you may want to replace the
|
||||
/// tensor<?xindex> with a known rank and type, e.g. tensor<1xi32>, or
|
||||
/// replacing an operation with a constant.
|
||||
virtual void populateReductionPatterns(RewritePatternSet &patterns) const = 0;
|
||||
|
||||
protected:
|
||||
DialectReductionPatternInterface(Dialect *dialect) : Base(dialect) {}
|
||||
};
|
||||
|
||||
} // end namespace mlir
|
||||
|
||||
#endif // MLIR_REDUCER_REDUCTIONPATTERNINTERFACE_H
|
|
@ -1,50 +0,0 @@
|
|||
//===- ReductionTreePass.h - Reduction Tree Pass Implementation -*- C++ -*-===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// This file defines the Reduction Tree Pass class. It provides a framework for
|
||||
// the implementation of different reduction passes in the MLIR Reduce tool. It
|
||||
// allows for custom specification of the variant generation behavior. It
|
||||
// implements methods that define the different possible traversals of the
|
||||
// reduction tree.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef MLIR_REDUCER_REDUCTIONTREEPASS_H
|
||||
#define MLIR_REDUCER_REDUCTIONTREEPASS_H
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "PassDetail.h"
|
||||
#include "ReductionNode.h"
|
||||
#include "mlir/Reducer/Passes/OpReducer.h"
|
||||
#include "mlir/Reducer/Tester.h"
|
||||
|
||||
#define DEBUG_TYPE "mlir-reduce"
|
||||
|
||||
namespace mlir {
|
||||
|
||||
/// This class defines the Reduction Tree Pass. It provides a framework to
|
||||
/// to implement a reduction pass using a tree structure to keep track of the
|
||||
/// generated reduced variants.
|
||||
class ReductionTreePass : public ReductionTreeBase<ReductionTreePass> {
|
||||
public:
|
||||
ReductionTreePass() = default;
|
||||
ReductionTreePass(const ReductionTreePass &pass) = default;
|
||||
|
||||
/// Runs the pass instance in the pass pipeline.
|
||||
void runOnOperation() override;
|
||||
|
||||
private:
|
||||
template <typename IteratorType>
|
||||
ModuleOp findOptimal(ModuleOp module, std::unique_ptr<OpReducer> reducer,
|
||||
ReductionNode *node);
|
||||
};
|
||||
|
||||
} // end namespace mlir
|
||||
|
||||
#endif
|
|
@ -1,7 +1,13 @@
|
|||
add_mlir_library(MLIRReduce
|
||||
OptReductionPass.cpp
|
||||
ReductionNode.cpp
|
||||
ReductionTreePass.cpp
|
||||
Tester.cpp
|
||||
LINK_LIBS PUBLIC
|
||||
MLIRIR
|
||||
MLIRPass
|
||||
MLIRRewrite
|
||||
MLIRTransformUtils
|
||||
)
|
||||
|
||||
mlir_check_all_link_libraries(MLIRReduce)
|
||||
|
|
|
@ -12,15 +12,27 @@
|
|||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/Reducer/OptReductionPass.h"
|
||||
#include "mlir/Pass/PassManager.h"
|
||||
#include "mlir/Pass/PassRegistry.h"
|
||||
#include "mlir/Reducer/PassDetail.h"
|
||||
#include "mlir/Reducer/Passes.h"
|
||||
#include "mlir/Reducer/Tester.h"
|
||||
#include "llvm/Support/Debug.h"
|
||||
|
||||
#define DEBUG_TYPE "mlir-reduce"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace {
|
||||
|
||||
class OptReductionPass : public OptReductionBase<OptReductionPass> {
|
||||
public:
|
||||
/// Runs the pass instance in the pass pipeline.
|
||||
void runOnOperation() override;
|
||||
};
|
||||
|
||||
} // end anonymous namespace
|
||||
|
||||
/// Runs the pass instance in the pass pipeline.
|
||||
void OptReductionPass::runOnOperation() {
|
||||
LLVM_DEBUG(llvm::dbgs() << "\nOptimization Reduction pass: ");
|
|
@ -15,6 +15,7 @@
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/Reducer/ReductionNode.h"
|
||||
#include "mlir/IR/BlockAndValueMapping.h"
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
|
||||
#include <algorithm>
|
||||
|
@ -23,102 +24,102 @@
|
|||
using namespace mlir;
|
||||
|
||||
ReductionNode::ReductionNode(
|
||||
ReductionNode *parent, std::vector<Range> ranges,
|
||||
ReductionNode *parentNode, std::vector<Range> ranges,
|
||||
llvm::SpecificBumpPtrAllocator<ReductionNode> &allocator)
|
||||
: size(std::numeric_limits<size_t>::max()),
|
||||
interesting(Tester::Interestingness::Untested),
|
||||
/// Root node will have the parent pointer point to themselves.
|
||||
parent(parent == nullptr ? this : parent), ranges(ranges),
|
||||
allocator(allocator) {}
|
||||
|
||||
/// Returns the size in bytes of the module.
|
||||
size_t ReductionNode::getSize() const { return size; }
|
||||
|
||||
ReductionNode *ReductionNode::getParent() const { return parent; }
|
||||
|
||||
/// Returns true if the module exhibits the interesting behavior.
|
||||
Tester::Interestingness ReductionNode::isInteresting() const {
|
||||
return interesting;
|
||||
/// Root node will have the parent pointer point to themselves.
|
||||
: parent(parentNode == nullptr ? this : parentNode),
|
||||
size(std::numeric_limits<size_t>::max()),
|
||||
interesting(Tester::Interestingness::Untested), ranges(ranges),
|
||||
startRanges(ranges), allocator(allocator) {
|
||||
if (parent != this)
|
||||
if (failed(initialize(parent->getModule(), parent->getRegion())))
|
||||
llvm_unreachable("unexpected initialization failure");
|
||||
}
|
||||
|
||||
std::vector<ReductionNode::Range> ReductionNode::getRanges() const {
|
||||
return ranges;
|
||||
LogicalResult ReductionNode::initialize(ModuleOp parentModule,
|
||||
Region &targetRegion) {
|
||||
// Use the mapper help us find the corresponding region after module clone.
|
||||
BlockAndValueMapping mapper;
|
||||
module = cast<ModuleOp>(parentModule->clone(mapper));
|
||||
// Use the first block of targetRegion to locate the cloned region.
|
||||
Block *block = mapper.lookup(&*targetRegion.begin());
|
||||
region = block->getParent();
|
||||
return success();
|
||||
}
|
||||
|
||||
std::vector<ReductionNode *> &ReductionNode::getVariants() { return variants; }
|
||||
|
||||
#include <iostream>
|
||||
|
||||
/// If we haven't explored any variants from this node, we will create N
|
||||
/// variants, N is the length of `ranges` if N > 1. Otherwise, we will split the
|
||||
/// max element in `ranges` and create 2 new variants for each call.
|
||||
std::vector<ReductionNode *> ReductionNode::generateNewVariants() {
|
||||
std::vector<ReductionNode *> newNodes;
|
||||
ArrayRef<ReductionNode *> ReductionNode::generateNewVariants() {
|
||||
int oldNumVariant = getVariants().size();
|
||||
|
||||
auto createNewNode = [this](std::vector<Range> ranges) {
|
||||
return new (allocator.Allocate())
|
||||
ReductionNode(this, std::move(ranges), allocator);
|
||||
};
|
||||
|
||||
// If we haven't created new variant, then we can create varients by removing
|
||||
// each of them respectively. For example, given {{1, 3}, {4, 9}}, we can
|
||||
// produce variants with range {{1, 3}} and {{4, 9}}.
|
||||
if (variants.size() == 0 && ranges.size() != 1) {
|
||||
for (const Range &range : ranges) {
|
||||
std::vector<Range> subRanges = ranges;
|
||||
if (variants.size() == 0 && getRanges().size() > 1) {
|
||||
for (const Range &range : getRanges()) {
|
||||
std::vector<Range> subRanges = getRanges();
|
||||
llvm::erase_value(subRanges, range);
|
||||
ReductionNode *newNode = allocator.Allocate();
|
||||
new (newNode) ReductionNode(this, subRanges, allocator);
|
||||
newNodes.push_back(newNode);
|
||||
variants.push_back(newNode);
|
||||
variants.push_back(createNewNode(std::move(subRanges)));
|
||||
}
|
||||
|
||||
return newNodes;
|
||||
return getVariants().drop_front(oldNumVariant);
|
||||
}
|
||||
|
||||
// At here, we have created the type of variants mentioned above. We would
|
||||
// like to split the max range into 2 to create 2 new variants. Continue on
|
||||
// the above example, we split the range {4, 9} into {4, 6}, {6, 9}, and
|
||||
// create two variants with range {{1, 3}, {4, 6}} and {{1, 3}, {6, 9}}. The
|
||||
// result ranges vector will be {{1, 3}, {4, 6}, {6, 9}}.
|
||||
// final ranges vector will be {{1, 3}, {4, 6}, {6, 9}}.
|
||||
auto maxElement = std::max_element(
|
||||
ranges.begin(), ranges.end(), [](const Range &lhs, const Range &rhs) {
|
||||
return (lhs.second - lhs.first) > (rhs.second - rhs.first);
|
||||
});
|
||||
|
||||
// We can't split range with lenght 1, which means we can't produce new
|
||||
// The length of range is less than 1, we can't split it to create new
|
||||
// variant.
|
||||
if (maxElement->second - maxElement->first == 1)
|
||||
if (maxElement->second - maxElement->first <= 1)
|
||||
return {};
|
||||
|
||||
auto createNewNode = [this](const std::vector<Range> &ranges) {
|
||||
ReductionNode *newNode = allocator.Allocate();
|
||||
new (newNode) ReductionNode(this, ranges, allocator);
|
||||
return newNode;
|
||||
};
|
||||
|
||||
Range maxRange = *maxElement;
|
||||
std::vector<Range> subRanges = ranges;
|
||||
std::vector<Range> subRanges = getRanges();
|
||||
auto subRangesIter = subRanges.begin() + (maxElement - ranges.begin());
|
||||
int half = (maxRange.first + maxRange.second) / 2;
|
||||
*subRangesIter = std::make_pair(maxRange.first, half);
|
||||
newNodes.push_back(createNewNode(subRanges));
|
||||
variants.push_back(createNewNode(subRanges));
|
||||
*subRangesIter = std::make_pair(half, maxRange.second);
|
||||
newNodes.push_back(createNewNode(subRanges));
|
||||
variants.push_back(createNewNode(std::move(subRanges)));
|
||||
|
||||
variants.insert(variants.end(), newNodes.begin(), newNodes.end());
|
||||
auto it = ranges.insert(maxElement, std::make_pair(half, maxRange.second));
|
||||
it = ranges.insert(it, std::make_pair(maxRange.first, half));
|
||||
// Remove the range that has been split.
|
||||
ranges.erase(it + 2);
|
||||
|
||||
return newNodes;
|
||||
return getVariants().drop_front(oldNumVariant);
|
||||
}
|
||||
|
||||
void ReductionNode::update(std::pair<Tester::Interestingness, size_t> result) {
|
||||
std::tie(interesting, size) = result;
|
||||
// After applying reduction, the number of operation in the region may have
|
||||
// changed. Non-interesting case won't be explored thus it's safe to keep it
|
||||
// in a stale status.
|
||||
if (interesting == Tester::Interestingness::True) {
|
||||
// This module may has been updated. Reset the range.
|
||||
ranges.clear();
|
||||
ranges.push_back({0, std::distance(region->op_begin(), region->op_end())});
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<ReductionNode *>
|
||||
ArrayRef<ReductionNode *>
|
||||
ReductionNode::iterator<SinglePath>::getNeighbors(ReductionNode *node) {
|
||||
// Single Path: Traverses the smallest successful variant at each level until
|
||||
// no new successful variants can be created at that level.
|
||||
llvm::ArrayRef<ReductionNode *> variantsFromParent =
|
||||
ArrayRef<ReductionNode *> variantsFromParent =
|
||||
node->getParent()->getVariants();
|
||||
|
||||
// The parent node created several variants and they may be waiting for
|
||||
|
@ -139,7 +140,8 @@ ReductionNode::iterator<SinglePath>::getNeighbors(ReductionNode *node) {
|
|||
smallest = node;
|
||||
}
|
||||
|
||||
if (smallest != nullptr) {
|
||||
if (smallest != nullptr &&
|
||||
smallest->getSize() < node->getParent()->getSize()) {
|
||||
// We got a smallest one, keep traversing from this node.
|
||||
node = smallest;
|
||||
} else {
|
|
@ -0,0 +1,247 @@
|
|||
//===- ReductionTreePass.cpp - ReductionTreePass Implementation -----------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// This file defines the Reduction Tree Pass class. It provides a framework for
|
||||
// the implementation of different reduction passes in the MLIR Reduce tool. It
|
||||
// allows for custom specification of the variant generation behavior. It
|
||||
// implements methods that define the different possible traversals of the
|
||||
// reduction tree.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/IR/DialectInterface.h"
|
||||
#include "mlir/IR/OpDefinition.h"
|
||||
#include "mlir/Reducer/PassDetail.h"
|
||||
#include "mlir/Reducer/Passes.h"
|
||||
#include "mlir/Reducer/ReductionNode.h"
|
||||
#include "mlir/Reducer/ReductionPatternInterface.h"
|
||||
#include "mlir/Reducer/Tester.h"
|
||||
#include "mlir/Rewrite/FrozenRewritePatternSet.h"
|
||||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
||||
|
||||
#include "llvm/ADT/ArrayRef.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
#include "llvm/Support/Allocator.h"
|
||||
#include "llvm/Support/ManagedStatic.h"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
/// We implicitly number each operation in the region and if an operation's
|
||||
/// number falls into rangeToKeep, we need to keep it and apply the given
|
||||
/// rewrite patterns on it.
|
||||
static void applyPatterns(Region ®ion,
|
||||
const FrozenRewritePatternSet &patterns,
|
||||
ArrayRef<ReductionNode::Range> rangeToKeep,
|
||||
bool eraseOpNotInRange) {
|
||||
std::vector<Operation *> opsNotInRange;
|
||||
std::vector<Operation *> opsInRange;
|
||||
size_t keepIndex = 0;
|
||||
for (auto op : enumerate(region.getOps())) {
|
||||
int index = op.index();
|
||||
if (keepIndex < rangeToKeep.size() &&
|
||||
index == rangeToKeep[keepIndex].second)
|
||||
++keepIndex;
|
||||
if (keepIndex == rangeToKeep.size() || index < rangeToKeep[keepIndex].first)
|
||||
opsNotInRange.push_back(&op.value());
|
||||
else
|
||||
opsInRange.push_back(&op.value());
|
||||
}
|
||||
|
||||
// `applyOpPatternsAndFold` may erase the ops so we can't do the pattern
|
||||
// matching in above iteration. Besides, erase op not-in-range may end up in
|
||||
// invalid module, so `applyOpPatternsAndFold` should come before that
|
||||
// transform.
|
||||
for (Operation *op : opsInRange)
|
||||
// `applyOpPatternsAndFold` returns whether the op is convered. Omit it
|
||||
// because we don't have expectation this reduction will be success or not.
|
||||
(void)applyOpPatternsAndFold(op, patterns);
|
||||
|
||||
if (eraseOpNotInRange)
|
||||
for (Operation *op : opsNotInRange) {
|
||||
op->dropAllUses();
|
||||
op->erase();
|
||||
}
|
||||
}
|
||||
|
||||
/// We will apply the reducer patterns to the operations in the ranges specified
|
||||
/// by ReductionNode. Note that we are not able to remove an operation without
|
||||
/// replacing it with another valid operation. However, The validity of module
|
||||
/// reduction is based on the Tester provided by the user and that means certain
|
||||
/// invalid module is still interested by the use. Thus we provide an
|
||||
/// alternative way to remove operations, which is using `eraseOpNotInRange` to
|
||||
/// erase the operations not in the range specified by ReductionNode.
|
||||
template <typename IteratorType>
|
||||
static void findOptimal(ModuleOp module, Region ®ion,
|
||||
const FrozenRewritePatternSet &patterns,
|
||||
const Tester &test, bool eraseOpNotInRange) {
|
||||
std::pair<Tester::Interestingness, size_t> initStatus =
|
||||
test.isInteresting(module);
|
||||
// While exploring the reduction tree, we always branch from an interesting
|
||||
// node. Thus the root node must be interesting.
|
||||
if (initStatus.first != Tester::Interestingness::True)
|
||||
return;
|
||||
|
||||
llvm::SpecificBumpPtrAllocator<ReductionNode> allocator;
|
||||
|
||||
std::vector<ReductionNode::Range> ranges{
|
||||
{0, std::distance(region.op_begin(), region.op_end())}};
|
||||
|
||||
ReductionNode *root = allocator.Allocate();
|
||||
new (root) ReductionNode(nullptr, std::move(ranges), allocator);
|
||||
// Duplicate the module for root node and locate the region in the copy.
|
||||
if (failed(root->initialize(module, region)))
|
||||
llvm_unreachable("unexpected initialization failure");
|
||||
root->update(initStatus);
|
||||
|
||||
ReductionNode *smallestNode = root;
|
||||
IteratorType iter(root);
|
||||
|
||||
while (iter != IteratorType::end()) {
|
||||
ReductionNode ¤tNode = *iter;
|
||||
Region &curRegion = currentNode.getRegion();
|
||||
|
||||
applyPatterns(curRegion, patterns, currentNode.getRanges(),
|
||||
eraseOpNotInRange);
|
||||
currentNode.update(test.isInteresting(currentNode.getModule()));
|
||||
|
||||
if (currentNode.isInteresting() == Tester::Interestingness::True &&
|
||||
currentNode.getSize() < smallestNode->getSize())
|
||||
smallestNode = ¤tNode;
|
||||
|
||||
++iter;
|
||||
}
|
||||
|
||||
// At here, we have found an optimal path to reduce the given region. Retrieve
|
||||
// the path and apply the reducer to it.
|
||||
SmallVector<ReductionNode *> trace;
|
||||
ReductionNode *curNode = smallestNode;
|
||||
trace.push_back(curNode);
|
||||
while (curNode != root) {
|
||||
curNode = curNode->getParent();
|
||||
trace.push_back(curNode);
|
||||
}
|
||||
|
||||
// Reduce the region through the optimal path.
|
||||
while (!trace.empty()) {
|
||||
ReductionNode *top = trace.pop_back_val();
|
||||
applyPatterns(region, patterns, top->getStartRanges(), eraseOpNotInRange);
|
||||
}
|
||||
|
||||
if (test.isInteresting(module).first != Tester::Interestingness::True)
|
||||
llvm::report_fatal_error("Reduced module is not interesting");
|
||||
if (test.isInteresting(module).second != smallestNode->getSize())
|
||||
llvm::report_fatal_error(
|
||||
"Reduced module doesn't have consistent size with smallestNode");
|
||||
}
|
||||
|
||||
template <typename IteratorType>
|
||||
static void findOptimal(ModuleOp module, Region ®ion,
|
||||
const FrozenRewritePatternSet &patterns,
|
||||
const Tester &test) {
|
||||
// We separate the reduction process into 2 steps, the first one is to erase
|
||||
// redundant operations and the second one is to apply the reducer patterns.
|
||||
|
||||
// In the first phase, we don't apply any patterns so that we only select the
|
||||
// range of operations to keep to the module stay interesting.
|
||||
findOptimal<IteratorType>(module, region, /*patterns=*/{}, test,
|
||||
/*eraseOpNotInRange=*/true);
|
||||
// In the second phase, we suppose that no operation is redundant, so we try
|
||||
// to rewrite the operation into simpler form.
|
||||
findOptimal<IteratorType>(module, region, patterns, test,
|
||||
/*eraseOpNotInRange=*/false);
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Reduction Pattern Interface Collection
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
class ReductionPatternInterfaceCollection
|
||||
: public DialectInterfaceCollection<DialectReductionPatternInterface> {
|
||||
public:
|
||||
using Base::Base;
|
||||
|
||||
// Collect the reduce patterns defined by each dialect.
|
||||
void populateReductionPatterns(RewritePatternSet &pattern) const {
|
||||
for (const DialectReductionPatternInterface &interface : *this)
|
||||
interface.populateReductionPatterns(pattern);
|
||||
}
|
||||
};
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ReductionTreePass
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// This class defines the Reduction Tree Pass. It provides a framework to
|
||||
/// to implement a reduction pass using a tree structure to keep track of the
|
||||
/// generated reduced variants.
|
||||
class ReductionTreePass : public ReductionTreeBase<ReductionTreePass> {
|
||||
public:
|
||||
ReductionTreePass() = default;
|
||||
ReductionTreePass(const ReductionTreePass &pass) = default;
|
||||
|
||||
LogicalResult initialize(MLIRContext *context) override;
|
||||
|
||||
/// Runs the pass instance in the pass pipeline.
|
||||
void runOnOperation() override;
|
||||
|
||||
private:
|
||||
void reduceOp(ModuleOp module, Region ®ion);
|
||||
|
||||
FrozenRewritePatternSet reducerPatterns;
|
||||
};
|
||||
|
||||
} // end anonymous namespace
|
||||
|
||||
LogicalResult ReductionTreePass::initialize(MLIRContext *context) {
|
||||
RewritePatternSet patterns(context);
|
||||
ReductionPatternInterfaceCollection reducePatternCollection(context);
|
||||
reducePatternCollection.populateReductionPatterns(patterns);
|
||||
reducerPatterns = std::move(patterns);
|
||||
return success();
|
||||
}
|
||||
|
||||
void ReductionTreePass::runOnOperation() {
|
||||
Operation *topOperation = getOperation();
|
||||
while (topOperation->getParentOp() != nullptr)
|
||||
topOperation = topOperation->getParentOp();
|
||||
ModuleOp module = cast<ModuleOp>(topOperation);
|
||||
|
||||
SmallVector<Operation *, 8> workList;
|
||||
workList.push_back(getOperation());
|
||||
|
||||
do {
|
||||
Operation *op = workList.pop_back_val();
|
||||
|
||||
for (Region ®ion : op->getRegions())
|
||||
if (!region.empty())
|
||||
reduceOp(module, region);
|
||||
|
||||
for (Region ®ion : op->getRegions())
|
||||
for (Operation &op : region.getOps())
|
||||
if (op.getNumRegions() != 0)
|
||||
workList.push_back(&op);
|
||||
} while (!workList.empty());
|
||||
}
|
||||
|
||||
void ReductionTreePass::reduceOp(ModuleOp module, Region ®ion) {
|
||||
Tester test(testerName, testerArgs);
|
||||
switch (traversalModeId) {
|
||||
case TraversalMode::SinglePath:
|
||||
findOptimal<ReductionNode::iterator<TraversalMode::SinglePath>>(
|
||||
module, region, reducerPatterns, test);
|
||||
break;
|
||||
default:
|
||||
llvm_unreachable("Unsupported mode");
|
||||
}
|
||||
}
|
||||
|
||||
std::unique_ptr<Pass> mlir::createReductionTreePass() {
|
||||
return std::make_unique<ReductionTreePass>();
|
||||
}
|
|
@ -15,7 +15,7 @@
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/Reducer/Tester.h"
|
||||
|
||||
#include "mlir/IR/Verifier.h"
|
||||
#include "llvm/Support/ToolOutputFile.h"
|
||||
|
||||
using namespace mlir;
|
||||
|
@ -25,6 +25,12 @@ Tester::Tester(StringRef scriptName, ArrayRef<std::string> scriptArgs)
|
|||
|
||||
std::pair<Tester::Interestingness, size_t>
|
||||
Tester::isInteresting(ModuleOp module) const {
|
||||
// The reduced module should always be vaild, or we may end up retaining the
|
||||
// error message by an invalid case. Besides, an invalid module may not be
|
||||
// able to print properly.
|
||||
if (failed(verify(module)))
|
||||
return std::make_pair(Interestingness::False, /*size=*/0);
|
||||
|
||||
SmallString<128> filepath;
|
||||
int fd;
|
||||
|
||||
|
@ -50,7 +56,6 @@ Tester::isInteresting(ModuleOp module) const {
|
|||
/// true if the interesting behavior is present in the test case or false
|
||||
/// otherwise.
|
||||
Tester::Interestingness Tester::isInteresting(StringRef testCase) const {
|
||||
|
||||
std::vector<StringRef> testerArgs;
|
||||
testerArgs.push_back(testCase);
|
||||
|
||||
|
|
|
@ -60,6 +60,7 @@ add_mlir_library(MLIRTestDialect
|
|||
MLIRInferTypeOpInterface
|
||||
MLIRLinalgTransforms
|
||||
MLIRPass
|
||||
MLIRReduce
|
||||
MLIRStandard
|
||||
MLIRStandardOpsTransforms
|
||||
MLIRTransformUtils
|
||||
|
|
|
@ -8,6 +8,7 @@
|
|||
|
||||
#include "TestDialect.h"
|
||||
#include "TestAttributes.h"
|
||||
#include "TestInterfaces.h"
|
||||
#include "TestTypes.h"
|
||||
#include "mlir/Dialect/DLTI/DLTI.h"
|
||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||
|
@ -16,6 +17,7 @@
|
|||
#include "mlir/IR/DialectImplementation.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/IR/TypeUtilities.h"
|
||||
#include "mlir/Reducer/ReductionPatternInterface.h"
|
||||
#include "mlir/Transforms/FoldUtils.h"
|
||||
#include "mlir/Transforms/InliningUtils.h"
|
||||
#include "llvm/ADT/StringSwitch.h"
|
||||
|
@ -170,6 +172,18 @@ struct TestInlinerInterface : public DialectInlinerInterface {
|
|||
return builder.create<TestCastOp>(conversionLoc, resultType, input);
|
||||
}
|
||||
};
|
||||
|
||||
struct TestReductionPatternInterface : public DialectReductionPatternInterface {
|
||||
public:
|
||||
TestReductionPatternInterface(Dialect *dialect)
|
||||
: DialectReductionPatternInterface(dialect) {}
|
||||
|
||||
virtual void
|
||||
populateReductionPatterns(RewritePatternSet &patterns) const final {
|
||||
populateTestReductionPatterns(patterns);
|
||||
}
|
||||
};
|
||||
|
||||
} // end anonymous namespace
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -207,7 +221,7 @@ void TestDialect::initialize() {
|
|||
#include "TestOps.cpp.inc"
|
||||
>();
|
||||
addInterfaces<TestOpAsmInterface, TestDialectFoldInterface,
|
||||
TestInlinerInterface>();
|
||||
TestInlinerInterface, TestReductionPatternInterface>();
|
||||
allowUnknownOperations();
|
||||
|
||||
// Instantiate our fallback op interface that we'll use on specific
|
||||
|
|
|
@ -34,6 +34,7 @@
|
|||
|
||||
namespace mlir {
|
||||
class DLTIDialect;
|
||||
class RewritePatternSet;
|
||||
} // namespace mlir
|
||||
|
||||
#include "TestOpEnums.h.inc"
|
||||
|
@ -47,6 +48,7 @@ class DLTIDialect;
|
|||
namespace mlir {
|
||||
namespace test {
|
||||
void registerTestDialect(DialectRegistry ®istry);
|
||||
void populateTestReductionPatterns(RewritePatternSet &patterns);
|
||||
} // namespace test
|
||||
} // namespace mlir
|
||||
|
||||
|
|
|
@ -2113,4 +2113,19 @@ def DataLayoutQueryOp : TEST_Op<"data_layout_query"> {
|
|||
let results = (outs AnyType:$res);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Test Reducer Patterns
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def OpCrashLong : TEST_Op<"op_crash_long"> {
|
||||
let arguments = (ins I32, I32, I32);
|
||||
let results = (outs I32);
|
||||
}
|
||||
|
||||
def OpCrashShort : TEST_Op<"op_crash_short"> {
|
||||
let results = (outs I32);
|
||||
}
|
||||
|
||||
def : Pat<(OpCrashLong $_, $_, $_), (OpCrashShort)>;
|
||||
|
||||
#endif // TEST_OPS
|
||||
|
|
|
@ -58,6 +58,14 @@ namespace {
|
|||
#include "TestPatterns.inc"
|
||||
} // end anonymous namespace
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Test Reduce Pattern Interface
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
void mlir::test::populateTestReductionPatterns(RewritePatternSet &patterns) {
|
||||
populateWithGenerated(patterns);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Canonicalizer Driver.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -38,7 +38,7 @@ void TestReducer::runOnFunction() {
|
|||
op.walk([&](Operation *op) {
|
||||
StringRef opName = op->getName().getStringRef();
|
||||
|
||||
if (opName == "test.crashOp") {
|
||||
if (opName.contains("op_crash")) {
|
||||
llvm::errs() << "MLIR Reducer Test generated failure: Found "
|
||||
"\"crashOp\" operation\n";
|
||||
exit(1);
|
||||
|
|
|
@ -0,0 +1,20 @@
|
|||
// UNSUPPORTED: system-windows
|
||||
// RUN: mlir-reduce %s -reduction-tree='traversal-mode=0 test=%S/failure-test.sh' | FileCheck %s
|
||||
// "test.op_crash_long" should be replaced with a shorter form "test.op_crash_short".
|
||||
|
||||
// CHECK-NOT: func @simple1() {
|
||||
func @simple1() {
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @simple2(%arg0: i32, %arg1: i32, %arg2: i32) {
|
||||
func @simple2(%arg0: i32, %arg1: i32, %arg2: i32) {
|
||||
// CHECK-LABEL: %0 = "test.op_crash_short"() : () -> i32
|
||||
%0 = "test.op_crash_long" (%arg0, %arg1, %arg2) : (i32, i32, i32) -> i32
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK-NOT: func @simple5() {
|
||||
func @simple5() {
|
||||
return
|
||||
}
|
|
@ -12,6 +12,6 @@ func nested @dead_nested_function()
|
|||
|
||||
// CHECK-LABEL: func @simple1(%arg0: i1, %arg1: memref<2xf32>, %arg2: memref<2xf32>) {
|
||||
func @simple1(%arg0: i1, %arg1: memref<2xf32>, %arg2: memref<2xf32>) {
|
||||
"test.crashOp" () : () -> ()
|
||||
"test.op_crash" () : () -> ()
|
||||
return
|
||||
}
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
// UNSUPPORTED: system-windows
|
||||
// RUN: mlir-reduce %s -reduction-tree='op-reducer=func traversal-mode=0 test=%S/failure-test.sh' | FileCheck %s
|
||||
// RUN: mlir-reduce %s -reduction-tree='traversal-mode=0 test=%S/failure-test.sh' | FileCheck %s
|
||||
// This input should be reduced by the pass pipeline so that only
|
||||
// the @simple5 function remains as this is the shortest function
|
||||
// containing the interesting behavior.
|
||||
|
@ -16,7 +16,7 @@ func @simple2() {
|
|||
|
||||
// CHECK-LABEL: func @simple3() {
|
||||
func @simple3() {
|
||||
"test.crashOp" () : () -> ()
|
||||
"test.op_crash" () : () -> ()
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -29,7 +29,7 @@ func @simple4(%arg0: i1, %arg1: memref<2xf32>, %arg2: memref<2xf32>) {
|
|||
%0 = memref.alloc() : memref<2xf32>
|
||||
br ^bb3(%0 : memref<2xf32>)
|
||||
^bb3(%1: memref<2xf32>):
|
||||
"test.crashOp"(%1, %arg2) : (memref<2xf32>, memref<2xf32>) -> ()
|
||||
"test.op_crash"(%1, %arg2) : (memref<2xf32>, memref<2xf32>) -> ()
|
||||
return
|
||||
}
|
||||
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
// UNSUPPORTED: system-windows
|
||||
// RUN: mlir-reduce %s -reduction-tree='op-reducer=func traversal-mode=0 test=%S/test.sh'
|
||||
// RUN: mlir-reduce %s -reduction-tree='traversal-mode=0 test=%S/test.sh'
|
||||
|
||||
func @simple1(%arg0: i1, %arg1: memref<2xf32>, %arg2: memref<2xf32>) {
|
||||
cond_br %arg0, ^bb1, ^bb2
|
||||
|
|
|
@ -2,6 +2,6 @@
|
|||
// RUN: not mlir-opt %s -test-mlir-reducer -pass-test function-reducer
|
||||
|
||||
func @test() {
|
||||
"test.crashOp"() : () -> ()
|
||||
"test.op_crash"() : () -> ()
|
||||
return
|
||||
}
|
||||
|
|
|
@ -43,9 +43,6 @@ set(LIBS
|
|||
)
|
||||
|
||||
add_llvm_tool(mlir-reduce
|
||||
OptReductionPass.cpp
|
||||
ReductionNode.cpp
|
||||
ReductionTreePass.cpp
|
||||
mlir-reduce.cpp
|
||||
|
||||
ADDITIONAL_HEADER_DIRS
|
||||
|
|
|
@ -1,107 +0,0 @@
|
|||
//===- ReductionTreePass.cpp - ReductionTreePass Implementation -----------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// This file defines the Reduction Tree Pass class. It provides a framework for
|
||||
// the implementation of different reduction passes in the MLIR Reduce tool. It
|
||||
// allows for custom specification of the variant generation behavior. It
|
||||
// implements methods that define the different possible traversals of the
|
||||
// reduction tree.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/Reducer/ReductionTreePass.h"
|
||||
#include "mlir/Reducer/Passes.h"
|
||||
|
||||
#include "llvm/Support/Allocator.h"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
static std::unique_ptr<OpReducer> getOpReducer(llvm::StringRef opType) {
|
||||
if (opType == ModuleOp::getOperationName())
|
||||
return std::make_unique<Reducer<ModuleOp>>();
|
||||
else if (opType == FuncOp::getOperationName())
|
||||
return std::make_unique<Reducer<FuncOp>>();
|
||||
llvm_unreachable("Now only supports two built-in ops");
|
||||
}
|
||||
|
||||
void ReductionTreePass::runOnOperation() {
|
||||
ModuleOp module = this->getOperation();
|
||||
std::unique_ptr<OpReducer> reducer = getOpReducer(opReducerName);
|
||||
std::vector<std::pair<int, int>> ranges = {
|
||||
{0, reducer->getNumTargetOps(module)}};
|
||||
|
||||
llvm::SpecificBumpPtrAllocator<ReductionNode> allocator;
|
||||
|
||||
ReductionNode *root = allocator.Allocate();
|
||||
new (root) ReductionNode(nullptr, ranges, allocator);
|
||||
|
||||
ModuleOp golden = module;
|
||||
switch (traversalModeId) {
|
||||
case TraversalMode::SinglePath:
|
||||
golden = findOptimal<ReductionNode::iterator<TraversalMode::SinglePath>>(
|
||||
module, std::move(reducer), root);
|
||||
break;
|
||||
default:
|
||||
llvm_unreachable("Unsupported mode");
|
||||
}
|
||||
|
||||
if (golden != module) {
|
||||
module.getBody()->clear();
|
||||
module.getBody()->getOperations().splice(module.getBody()->begin(),
|
||||
golden.getBody()->getOperations());
|
||||
golden->destroy();
|
||||
}
|
||||
}
|
||||
|
||||
template <typename IteratorType>
|
||||
ModuleOp ReductionTreePass::findOptimal(ModuleOp module,
|
||||
std::unique_ptr<OpReducer> reducer,
|
||||
ReductionNode *root) {
|
||||
Tester test(testerName, testerArgs);
|
||||
std::pair<Tester::Interestingness, size_t> initStatus =
|
||||
test.isInteresting(module);
|
||||
|
||||
if (initStatus.first != Tester::Interestingness::True) {
|
||||
LLVM_DEBUG(llvm::dbgs() << "\nThe original input is not interested");
|
||||
return module;
|
||||
}
|
||||
|
||||
root->update(initStatus);
|
||||
|
||||
ReductionNode *smallestNode = root;
|
||||
ModuleOp golden = module;
|
||||
|
||||
IteratorType iter(root);
|
||||
|
||||
while (iter != IteratorType::end()) {
|
||||
ModuleOp cloneModule = module.clone();
|
||||
|
||||
ReductionNode ¤tNode = *iter;
|
||||
reducer->reduce(cloneModule, currentNode.getRanges());
|
||||
|
||||
std::pair<Tester::Interestingness, size_t> result =
|
||||
test.isInteresting(cloneModule);
|
||||
currentNode.update(result);
|
||||
|
||||
if (result.first == Tester::Interestingness::True &&
|
||||
result.second < smallestNode->getSize()) {
|
||||
smallestNode = ¤tNode;
|
||||
golden = cloneModule;
|
||||
} else {
|
||||
cloneModule->destroy();
|
||||
}
|
||||
|
||||
++iter;
|
||||
}
|
||||
|
||||
return golden;
|
||||
}
|
||||
|
||||
std::unique_ptr<Pass> mlir::createReductionTreePass() {
|
||||
return std::make_unique<ReductionTreePass>();
|
||||
}
|
|
@ -13,22 +13,14 @@
|
|||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "mlir/InitAllDialects.h"
|
||||
#include "mlir/InitAllPasses.h"
|
||||
#include "mlir/Parser.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "mlir/Pass/PassManager.h"
|
||||
#include "mlir/Reducer/OptReductionPass.h"
|
||||
#include "mlir/Reducer/Passes.h"
|
||||
#include "mlir/Reducer/Passes/OpReducer.h"
|
||||
#include "mlir/Reducer/ReductionNode.h"
|
||||
#include "mlir/Reducer/ReductionTreePass.h"
|
||||
#include "mlir/Reducer/Tester.h"
|
||||
#include "mlir/Support/FileUtilities.h"
|
||||
#include "mlir/Support/LogicalResult.h"
|
||||
#include "mlir/Transforms/Passes.h"
|
||||
#include "llvm/Support/InitLLVM.h"
|
||||
#include "llvm/Support/ToolOutputFile.h"
|
||||
|
||||
|
|
Loading…
Reference in New Issue