forked from OSchip/llvm-project
Add constant folding and binary operator reassociation to the canonicalize
pass, build up the worklist infra in anticipation of improving the pattern matcher to match more than one node. PiperOrigin-RevId: 217330579
This commit is contained in:
parent
58168e476e
commit
80e884a9f8
|
@ -149,6 +149,9 @@ public:
|
|||
state->setAttr(name, value);
|
||||
}
|
||||
|
||||
/// Return true if there are no users of any results of this operation.
|
||||
bool use_empty() const { return state->use_empty(); }
|
||||
|
||||
/// Emit an error about fatal conditions with this operation, reporting up to
|
||||
/// any diagnostic handlers that may be listening. NOTE: This may terminate
|
||||
/// the containing application, only use when the IR is in an inconsistent
|
||||
|
@ -482,6 +485,9 @@ public:
|
|||
|
||||
Type *getType() const { return getResult()->getType(); }
|
||||
|
||||
/// Return true if there are no users of any results of this operation.
|
||||
bool use_empty() const { return getResult()->use_empty(); }
|
||||
|
||||
/// Replace all uses of 'this' value with the new value, updating anything in
|
||||
/// the IR that uses 'this' to use the other value instead. When this returns
|
||||
/// there are zero uses of 'this'.
|
||||
|
|
|
@ -96,6 +96,9 @@ public:
|
|||
const_result_iterator result_end() const;
|
||||
llvm::iterator_range<const_result_iterator> getResults() const;
|
||||
|
||||
/// Return true if there are no users of any results of this operation.
|
||||
bool use_empty() const;
|
||||
|
||||
// Attributes. Operations may optionally carry a list of attributes that
|
||||
// associate constants to names. Attributes may be dynamically added and
|
||||
// removed over the lifetime of an operation.
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
//===- mlir/Pass.h - Base classes for compiler passes -----------*- C++ -*-===//
|
||||
//===- Pass.h - Base classes for compiler passes ----------------*- C++ -*-===//
|
||||
//
|
||||
// Copyright 2019 The MLIR Authors.
|
||||
//
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
//===- mlir/PatternMatch.h - Base classes for pattern match -----*- C++ -*-===//
|
||||
//===- PatternMatch.h - Base classes for pattern match ----------*- C++ -*-===//
|
||||
//
|
||||
// Copyright 2019 The MLIR Authors.
|
||||
//
|
||||
|
@ -19,8 +19,6 @@
|
|||
#define MLIR_PATTERN_MATCH_H
|
||||
|
||||
#include "mlir/IR/OperationSupport.h"
|
||||
#include "mlir/Support/LLVM.h"
|
||||
#include "llvm/ADT/ArrayRef.h"
|
||||
|
||||
namespace mlir {
|
||||
|
||||
|
|
|
@ -131,6 +131,14 @@ SSAValue *Operation::getResult(unsigned idx) {
|
|||
return cast<OperationStmt>(this)->getResult(idx);
|
||||
}
|
||||
|
||||
/// Return true if there are no users of any results of this operation.
|
||||
bool Operation::use_empty() const {
|
||||
for (auto *result : getResults())
|
||||
if (!result->use_empty())
|
||||
return false;
|
||||
return true;
|
||||
}
|
||||
|
||||
ArrayRef<NamedAttribute> Operation::getAttrs() const {
|
||||
if (!attrs)
|
||||
return {};
|
||||
|
|
|
@ -26,6 +26,7 @@
|
|||
#include "mlir/Transforms/Pass.h"
|
||||
#include "mlir/Transforms/Passes.h"
|
||||
#include "mlir/Transforms/PatternMatch.h"
|
||||
#include "llvm/ADT/DenseMap.h"
|
||||
|
||||
#include <memory>
|
||||
using namespace mlir;
|
||||
|
@ -87,6 +88,39 @@ struct Canonicalizer : public FunctionPass {
|
|||
|
||||
void simplifyFunction(std::vector<Operation *> &worklist,
|
||||
MLFuncBuilder &builder);
|
||||
|
||||
void addToWorklist(Operation *op) {
|
||||
worklistMap[op] = worklist.size();
|
||||
worklist.push_back(op);
|
||||
}
|
||||
|
||||
Operation *popFromWorklist() {
|
||||
auto *op = worklist.back();
|
||||
worklist.pop_back();
|
||||
|
||||
// This operation is no longer in the worklist, keep worklistMap up to date.
|
||||
if (op)
|
||||
worklistMap.erase(op);
|
||||
return op;
|
||||
}
|
||||
|
||||
/// If the specified operation is in the worklist, remove it. If not, this is
|
||||
/// a no-op.
|
||||
void removeFromWorklist(Operation *op) {
|
||||
auto it = worklistMap.find(op);
|
||||
if (it != worklistMap.end()) {
|
||||
assert(worklist[it->second] == op && "malformed worklist data structure");
|
||||
worklist[it->second] = nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
/// The worklist for this transformation keeps track of the operations that
|
||||
/// need to be revisited, plus their index in the worklist. This allows us to
|
||||
/// efficiently remove operations from the worklist when they are removed even
|
||||
/// if they aren't the root of a pattern.
|
||||
std::vector<Operation *> worklist;
|
||||
DenseMap<Operation *, unsigned> worklistMap;
|
||||
};
|
||||
} // end anonymous namespace
|
||||
|
||||
|
@ -96,10 +130,9 @@ PassResult Canonicalizer::runOnCFGFunction(CFGFunction *f) {
|
|||
}
|
||||
|
||||
PassResult Canonicalizer::runOnMLFunction(MLFunction *f) {
|
||||
std::vector<Operation *> worklist;
|
||||
worklist.reserve(64);
|
||||
|
||||
f->walk([&](OperationStmt *stmt) { worklist.push_back(stmt); });
|
||||
f->walk([&](OperationStmt *stmt) { addToWorklist(stmt); });
|
||||
|
||||
MLFuncBuilder builder(f);
|
||||
simplifyFunction(worklist, builder);
|
||||
|
@ -114,15 +147,69 @@ void Canonicalizer::simplifyFunction(std::vector<Operation *> &worklist,
|
|||
|
||||
PatternMatcher matcher({new SimplifyXMinusX(builder.getContext())});
|
||||
|
||||
// These are scratch vectors used in the constant folding loop below.
|
||||
SmallVector<Attribute *, 8> operandConstants, resultConstants;
|
||||
|
||||
while (!worklist.empty()) {
|
||||
auto *op = worklist.back();
|
||||
worklist.pop_back();
|
||||
auto *op = popFromWorklist();
|
||||
|
||||
// TODO: If no side effects, and operation has no users, then it is
|
||||
// trivially dead - remove it.
|
||||
// Nulls get added to the worklist when operations are removed, ignore them.
|
||||
if (op == nullptr)
|
||||
continue;
|
||||
|
||||
// TODO: Call the constant folding hook on this operation, and canonicalize
|
||||
// constants into the entry node.
|
||||
// If the operation has no side effects, and no users, then it is trivially
|
||||
// dead - remove it.
|
||||
if (op->hasNoSideEffect() && op->use_empty()) {
|
||||
// FIXME: Generalize to support CFG statements as well.
|
||||
cast<OperationStmt>(op)->eraseFromBlock();
|
||||
continue;
|
||||
}
|
||||
|
||||
// Check to see if any operands to the instruction is constant and whether
|
||||
// the operation knows how to constant fold itself.
|
||||
operandConstants.clear();
|
||||
for (auto *operand : op->getOperands()) {
|
||||
Attribute *operandCst = nullptr;
|
||||
if (auto *operandOp = operand->getDefiningOperation()) {
|
||||
if (auto operandConstantOp = operandOp->getAs<ConstantOp>())
|
||||
operandCst = operandConstantOp->getValue();
|
||||
}
|
||||
operandConstants.push_back(operandCst);
|
||||
}
|
||||
|
||||
// If constant folding was successful, create the result constants, RAUW the
|
||||
// operation and remove it.
|
||||
resultConstants.clear();
|
||||
if (!op->constantFold(operandConstants, resultConstants)) {
|
||||
// TODO: Put these in the entry block and unique them.
|
||||
FuncBuilder cstBuilder(builder);
|
||||
cstBuilder.setInsertionPoint(op);
|
||||
|
||||
for (unsigned i = 0, e = op->getNumResults(); i != e; ++i) {
|
||||
auto *res = op->getResult(i);
|
||||
if (res->use_empty()) // ignore dead uses.
|
||||
continue;
|
||||
|
||||
auto cst = cstBuilder.create<ConstantOp>(
|
||||
op->getLoc(), resultConstants[i], res->getType());
|
||||
res->replaceAllUsesWith(cst);
|
||||
}
|
||||
|
||||
assert(op->hasNoSideEffect() && "Constant folded op with side effects?");
|
||||
|
||||
// FIXME: Generalize to support CFG statements as well.
|
||||
cast<OperationStmt>(op)->eraseFromBlock();
|
||||
continue;
|
||||
}
|
||||
|
||||
// If this is an associative binary operation with a constant on the LHS,
|
||||
// move it to the right side.
|
||||
if (operandConstants.size() == 2 && operandConstants[0] &&
|
||||
!operandConstants[1]) {
|
||||
auto *newLHS = op->getOperand(1);
|
||||
op->setOperand(1, op->getOperand(0));
|
||||
op->setOperand(0, newLHS);
|
||||
}
|
||||
|
||||
// Check to see if we have any patterns that match this node.
|
||||
auto match = matcher.findMatch(op);
|
||||
|
@ -131,6 +218,8 @@ void Canonicalizer::simplifyFunction(std::vector<Operation *> &worklist,
|
|||
|
||||
// TODO: Need to be a bit trickier to make sure new instructions get into
|
||||
// the worklist.
|
||||
// TODO: Need to be careful to remove instructions from the worklist when
|
||||
// they are eliminated by the replace method.
|
||||
match.first->rewrite(op, std::move(match.second), builder);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,10 +1,31 @@
|
|||
// RUN: mlir-opt %s -canonicalize | FileCheck %s
|
||||
|
||||
// CHECK-LABEL: @test_subi_zero
|
||||
mlfunc @test_subi_zero(%x: i32) -> i32 {
|
||||
mlfunc @test_subi_zero(%arg0: i32) -> i32 {
|
||||
// CHECK: %c0_i32 = constant 0 : i32
|
||||
// CHECK-NEXT: return %c0
|
||||
%y = subi %x, %x : i32
|
||||
%y = subi %arg0, %arg0 : i32
|
||||
return %y: i32
|
||||
}
|
||||
|
||||
// CHECK-LABEL: mlfunc @dim
|
||||
mlfunc @dim(%arg0 : tensor<8x4xf32>) -> index {
|
||||
|
||||
// CHECK: %c4 = constant 4 : index
|
||||
%0 = dim %arg0, 1 : tensor<8x4xf32>
|
||||
|
||||
// CHECK-NEXT: return %c4
|
||||
return %0 : index
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @test_associative
|
||||
mlfunc @test_associative(%arg0: i32) -> i32 {
|
||||
// CHECK: %c42_i32 = constant 42 : i32
|
||||
// CHECK-NEXT: %0 = addi %arg0, %c42_i32 : i32
|
||||
// CHECK-NEXT: return %0
|
||||
|
||||
%c42_i32 = constant 42 : i32
|
||||
%y = addi %c42_i32, %arg0 : i32
|
||||
return %y: i32
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue