Add constant folding and binary operator reassociation to the canonicalize

pass, build up the worklist infra in anticipation of improving the pattern
matcher to match more than one node.

PiperOrigin-RevId: 217330579
This commit is contained in:
Chris Lattner 2018-10-16 09:31:45 -07:00 committed by jpienaar
parent 58168e476e
commit 80e884a9f8
7 changed files with 139 additions and 14 deletions

View File

@ -149,6 +149,9 @@ public:
state->setAttr(name, value);
}
/// Return true if there are no users of any results of this operation.
bool use_empty() const { return state->use_empty(); }
/// Emit an error about fatal conditions with this operation, reporting up to
/// any diagnostic handlers that may be listening. NOTE: This may terminate
/// the containing application, only use when the IR is in an inconsistent
@ -482,6 +485,9 @@ public:
Type *getType() const { return getResult()->getType(); }
/// Return true if there are no users of any results of this operation.
bool use_empty() const { return getResult()->use_empty(); }
/// Replace all uses of 'this' value with the new value, updating anything in
/// the IR that uses 'this' to use the other value instead. When this returns
/// there are zero uses of 'this'.

View File

@ -96,6 +96,9 @@ public:
const_result_iterator result_end() const;
llvm::iterator_range<const_result_iterator> getResults() const;
/// Return true if there are no users of any results of this operation.
bool use_empty() const;
// Attributes. Operations may optionally carry a list of attributes that
// associate constants to names. Attributes may be dynamically added and
// removed over the lifetime of an operation.

View File

@ -1,4 +1,4 @@
//===- mlir/Pass.h - Base classes for compiler passes -----------*- C++ -*-===//
//===- Pass.h - Base classes for compiler passes ----------------*- C++ -*-===//
//
// Copyright 2019 The MLIR Authors.
//

View File

@ -1,4 +1,4 @@
//===- mlir/PatternMatch.h - Base classes for pattern match -----*- C++ -*-===//
//===- PatternMatch.h - Base classes for pattern match ----------*- C++ -*-===//
//
// Copyright 2019 The MLIR Authors.
//
@ -19,8 +19,6 @@
#define MLIR_PATTERN_MATCH_H
#include "mlir/IR/OperationSupport.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/ArrayRef.h"
namespace mlir {

View File

@ -131,6 +131,14 @@ SSAValue *Operation::getResult(unsigned idx) {
return cast<OperationStmt>(this)->getResult(idx);
}
/// Return true if there are no users of any results of this operation.
bool Operation::use_empty() const {
for (auto *result : getResults())
if (!result->use_empty())
return false;
return true;
}
ArrayRef<NamedAttribute> Operation::getAttrs() const {
if (!attrs)
return {};

View File

@ -26,6 +26,7 @@
#include "mlir/Transforms/Pass.h"
#include "mlir/Transforms/Passes.h"
#include "mlir/Transforms/PatternMatch.h"
#include "llvm/ADT/DenseMap.h"
#include <memory>
using namespace mlir;
@ -87,6 +88,39 @@ struct Canonicalizer : public FunctionPass {
void simplifyFunction(std::vector<Operation *> &worklist,
MLFuncBuilder &builder);
void addToWorklist(Operation *op) {
worklistMap[op] = worklist.size();
worklist.push_back(op);
}
Operation *popFromWorklist() {
auto *op = worklist.back();
worklist.pop_back();
// This operation is no longer in the worklist, keep worklistMap up to date.
if (op)
worklistMap.erase(op);
return op;
}
/// If the specified operation is in the worklist, remove it. If not, this is
/// a no-op.
void removeFromWorklist(Operation *op) {
auto it = worklistMap.find(op);
if (it != worklistMap.end()) {
assert(worklist[it->second] == op && "malformed worklist data structure");
worklist[it->second] = nullptr;
}
}
private:
/// The worklist for this transformation keeps track of the operations that
/// need to be revisited, plus their index in the worklist. This allows us to
/// efficiently remove operations from the worklist when they are removed even
/// if they aren't the root of a pattern.
std::vector<Operation *> worklist;
DenseMap<Operation *, unsigned> worklistMap;
};
} // end anonymous namespace
@ -96,10 +130,9 @@ PassResult Canonicalizer::runOnCFGFunction(CFGFunction *f) {
}
PassResult Canonicalizer::runOnMLFunction(MLFunction *f) {
std::vector<Operation *> worklist;
worklist.reserve(64);
f->walk([&](OperationStmt *stmt) { worklist.push_back(stmt); });
f->walk([&](OperationStmt *stmt) { addToWorklist(stmt); });
MLFuncBuilder builder(f);
simplifyFunction(worklist, builder);
@ -114,15 +147,69 @@ void Canonicalizer::simplifyFunction(std::vector<Operation *> &worklist,
PatternMatcher matcher({new SimplifyXMinusX(builder.getContext())});
// These are scratch vectors used in the constant folding loop below.
SmallVector<Attribute *, 8> operandConstants, resultConstants;
while (!worklist.empty()) {
auto *op = worklist.back();
worklist.pop_back();
auto *op = popFromWorklist();
// TODO: If no side effects, and operation has no users, then it is
// trivially dead - remove it.
// Nulls get added to the worklist when operations are removed, ignore them.
if (op == nullptr)
continue;
// TODO: Call the constant folding hook on this operation, and canonicalize
// constants into the entry node.
// If the operation has no side effects, and no users, then it is trivially
// dead - remove it.
if (op->hasNoSideEffect() && op->use_empty()) {
// FIXME: Generalize to support CFG statements as well.
cast<OperationStmt>(op)->eraseFromBlock();
continue;
}
// Check to see if any operands to the instruction is constant and whether
// the operation knows how to constant fold itself.
operandConstants.clear();
for (auto *operand : op->getOperands()) {
Attribute *operandCst = nullptr;
if (auto *operandOp = operand->getDefiningOperation()) {
if (auto operandConstantOp = operandOp->getAs<ConstantOp>())
operandCst = operandConstantOp->getValue();
}
operandConstants.push_back(operandCst);
}
// If constant folding was successful, create the result constants, RAUW the
// operation and remove it.
resultConstants.clear();
if (!op->constantFold(operandConstants, resultConstants)) {
// TODO: Put these in the entry block and unique them.
FuncBuilder cstBuilder(builder);
cstBuilder.setInsertionPoint(op);
for (unsigned i = 0, e = op->getNumResults(); i != e; ++i) {
auto *res = op->getResult(i);
if (res->use_empty()) // ignore dead uses.
continue;
auto cst = cstBuilder.create<ConstantOp>(
op->getLoc(), resultConstants[i], res->getType());
res->replaceAllUsesWith(cst);
}
assert(op->hasNoSideEffect() && "Constant folded op with side effects?");
// FIXME: Generalize to support CFG statements as well.
cast<OperationStmt>(op)->eraseFromBlock();
continue;
}
// If this is an associative binary operation with a constant on the LHS,
// move it to the right side.
if (operandConstants.size() == 2 && operandConstants[0] &&
!operandConstants[1]) {
auto *newLHS = op->getOperand(1);
op->setOperand(1, op->getOperand(0));
op->setOperand(0, newLHS);
}
// Check to see if we have any patterns that match this node.
auto match = matcher.findMatch(op);
@ -131,6 +218,8 @@ void Canonicalizer::simplifyFunction(std::vector<Operation *> &worklist,
// TODO: Need to be a bit trickier to make sure new instructions get into
// the worklist.
// TODO: Need to be careful to remove instructions from the worklist when
// they are eliminated by the replace method.
match.first->rewrite(op, std::move(match.second), builder);
}
}

View File

@ -1,10 +1,31 @@
// RUN: mlir-opt %s -canonicalize | FileCheck %s
// CHECK-LABEL: @test_subi_zero
mlfunc @test_subi_zero(%x: i32) -> i32 {
mlfunc @test_subi_zero(%arg0: i32) -> i32 {
// CHECK: %c0_i32 = constant 0 : i32
// CHECK-NEXT: return %c0
%y = subi %x, %x : i32
%y = subi %arg0, %arg0 : i32
return %y: i32
}
// CHECK-LABEL: mlfunc @dim
mlfunc @dim(%arg0 : tensor<8x4xf32>) -> index {
// CHECK: %c4 = constant 4 : index
%0 = dim %arg0, 1 : tensor<8x4xf32>
// CHECK-NEXT: return %c4
return %0 : index
}
// CHECK-LABEL: @test_associative
mlfunc @test_associative(%arg0: i32) -> i32 {
// CHECK: %c42_i32 = constant 42 : i32
// CHECK-NEXT: %0 = addi %arg0, %c42_i32 : i32
// CHECK-NEXT: return %0
%c42_i32 = constant 42 : i32
%y = addi %c42_i32, %arg0 : i32
return %y: i32
}