From 80e884a9f83cf75a229e0a057e4a60313c7103e5 Mon Sep 17 00:00:00 2001 From: Chris Lattner Date: Tue, 16 Oct 2018 09:31:45 -0700 Subject: [PATCH] Add constant folding and binary operator reassociation to the canonicalize pass, build up the worklist infra in anticipation of improving the pattern matcher to match more than one node. PiperOrigin-RevId: 217330579 --- mlir/include/mlir/IR/OpDefinition.h | 6 ++ mlir/include/mlir/IR/Operation.h | 3 + mlir/include/mlir/Transforms/Pass.h | 2 +- mlir/include/mlir/Transforms/PatternMatch.h | 4 +- mlir/lib/IR/Operation.cpp | 8 ++ mlir/lib/Transforms/Canonicalizer.cpp | 105 ++++++++++++++++++-- mlir/test/Transforms/canonicalize.mlir | 25 ++++- 7 files changed, 139 insertions(+), 14 deletions(-) diff --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h index 4435ae0d86ee..bfbf25f6545e 100644 --- a/mlir/include/mlir/IR/OpDefinition.h +++ b/mlir/include/mlir/IR/OpDefinition.h @@ -149,6 +149,9 @@ public: state->setAttr(name, value); } + /// Return true if there are no users of any results of this operation. + bool use_empty() const { return state->use_empty(); } + /// Emit an error about fatal conditions with this operation, reporting up to /// any diagnostic handlers that may be listening. NOTE: This may terminate /// the containing application, only use when the IR is in an inconsistent @@ -482,6 +485,9 @@ public: Type *getType() const { return getResult()->getType(); } + /// Return true if there are no users of any results of this operation. + bool use_empty() const { return getResult()->use_empty(); } + /// Replace all uses of 'this' value with the new value, updating anything in /// the IR that uses 'this' to use the other value instead. When this returns /// there are zero uses of 'this'. diff --git a/mlir/include/mlir/IR/Operation.h b/mlir/include/mlir/IR/Operation.h index cb389bd4d807..ea0f68a405b1 100644 --- a/mlir/include/mlir/IR/Operation.h +++ b/mlir/include/mlir/IR/Operation.h @@ -96,6 +96,9 @@ public: const_result_iterator result_end() const; llvm::iterator_range getResults() const; + /// Return true if there are no users of any results of this operation. + bool use_empty() const; + // Attributes. Operations may optionally carry a list of attributes that // associate constants to names. Attributes may be dynamically added and // removed over the lifetime of an operation. diff --git a/mlir/include/mlir/Transforms/Pass.h b/mlir/include/mlir/Transforms/Pass.h index fadac58c5383..f18bc65be699 100644 --- a/mlir/include/mlir/Transforms/Pass.h +++ b/mlir/include/mlir/Transforms/Pass.h @@ -1,4 +1,4 @@ -//===- mlir/Pass.h - Base classes for compiler passes -----------*- C++ -*-===// +//===- Pass.h - Base classes for compiler passes ----------------*- C++ -*-===// // // Copyright 2019 The MLIR Authors. // diff --git a/mlir/include/mlir/Transforms/PatternMatch.h b/mlir/include/mlir/Transforms/PatternMatch.h index 4b0b35fc18ca..ac479f94ce36 100644 --- a/mlir/include/mlir/Transforms/PatternMatch.h +++ b/mlir/include/mlir/Transforms/PatternMatch.h @@ -1,4 +1,4 @@ -//===- mlir/PatternMatch.h - Base classes for pattern match -----*- C++ -*-===// +//===- PatternMatch.h - Base classes for pattern match ----------*- C++ -*-===// // // Copyright 2019 The MLIR Authors. // @@ -19,8 +19,6 @@ #define MLIR_PATTERN_MATCH_H #include "mlir/IR/OperationSupport.h" -#include "mlir/Support/LLVM.h" -#include "llvm/ADT/ArrayRef.h" namespace mlir { diff --git a/mlir/lib/IR/Operation.cpp b/mlir/lib/IR/Operation.cpp index 8714687d9c4f..e7a012a8a580 100644 --- a/mlir/lib/IR/Operation.cpp +++ b/mlir/lib/IR/Operation.cpp @@ -131,6 +131,14 @@ SSAValue *Operation::getResult(unsigned idx) { return cast(this)->getResult(idx); } +/// Return true if there are no users of any results of this operation. +bool Operation::use_empty() const { + for (auto *result : getResults()) + if (!result->use_empty()) + return false; + return true; +} + ArrayRef Operation::getAttrs() const { if (!attrs) return {}; diff --git a/mlir/lib/Transforms/Canonicalizer.cpp b/mlir/lib/Transforms/Canonicalizer.cpp index dabcdecf3746..ace8110e5800 100644 --- a/mlir/lib/Transforms/Canonicalizer.cpp +++ b/mlir/lib/Transforms/Canonicalizer.cpp @@ -26,6 +26,7 @@ #include "mlir/Transforms/Pass.h" #include "mlir/Transforms/Passes.h" #include "mlir/Transforms/PatternMatch.h" +#include "llvm/ADT/DenseMap.h" #include using namespace mlir; @@ -87,6 +88,39 @@ struct Canonicalizer : public FunctionPass { void simplifyFunction(std::vector &worklist, MLFuncBuilder &builder); + + void addToWorklist(Operation *op) { + worklistMap[op] = worklist.size(); + worklist.push_back(op); + } + + Operation *popFromWorklist() { + auto *op = worklist.back(); + worklist.pop_back(); + + // This operation is no longer in the worklist, keep worklistMap up to date. + if (op) + worklistMap.erase(op); + return op; + } + + /// If the specified operation is in the worklist, remove it. If not, this is + /// a no-op. + void removeFromWorklist(Operation *op) { + auto it = worklistMap.find(op); + if (it != worklistMap.end()) { + assert(worklist[it->second] == op && "malformed worklist data structure"); + worklist[it->second] = nullptr; + } + } + +private: + /// The worklist for this transformation keeps track of the operations that + /// need to be revisited, plus their index in the worklist. This allows us to + /// efficiently remove operations from the worklist when they are removed even + /// if they aren't the root of a pattern. + std::vector worklist; + DenseMap worklistMap; }; } // end anonymous namespace @@ -96,10 +130,9 @@ PassResult Canonicalizer::runOnCFGFunction(CFGFunction *f) { } PassResult Canonicalizer::runOnMLFunction(MLFunction *f) { - std::vector worklist; worklist.reserve(64); - f->walk([&](OperationStmt *stmt) { worklist.push_back(stmt); }); + f->walk([&](OperationStmt *stmt) { addToWorklist(stmt); }); MLFuncBuilder builder(f); simplifyFunction(worklist, builder); @@ -114,15 +147,69 @@ void Canonicalizer::simplifyFunction(std::vector &worklist, PatternMatcher matcher({new SimplifyXMinusX(builder.getContext())}); + // These are scratch vectors used in the constant folding loop below. + SmallVector operandConstants, resultConstants; + while (!worklist.empty()) { - auto *op = worklist.back(); - worklist.pop_back(); + auto *op = popFromWorklist(); - // TODO: If no side effects, and operation has no users, then it is - // trivially dead - remove it. + // Nulls get added to the worklist when operations are removed, ignore them. + if (op == nullptr) + continue; - // TODO: Call the constant folding hook on this operation, and canonicalize - // constants into the entry node. + // If the operation has no side effects, and no users, then it is trivially + // dead - remove it. + if (op->hasNoSideEffect() && op->use_empty()) { + // FIXME: Generalize to support CFG statements as well. + cast(op)->eraseFromBlock(); + continue; + } + + // Check to see if any operands to the instruction is constant and whether + // the operation knows how to constant fold itself. + operandConstants.clear(); + for (auto *operand : op->getOperands()) { + Attribute *operandCst = nullptr; + if (auto *operandOp = operand->getDefiningOperation()) { + if (auto operandConstantOp = operandOp->getAs()) + operandCst = operandConstantOp->getValue(); + } + operandConstants.push_back(operandCst); + } + + // If constant folding was successful, create the result constants, RAUW the + // operation and remove it. + resultConstants.clear(); + if (!op->constantFold(operandConstants, resultConstants)) { + // TODO: Put these in the entry block and unique them. + FuncBuilder cstBuilder(builder); + cstBuilder.setInsertionPoint(op); + + for (unsigned i = 0, e = op->getNumResults(); i != e; ++i) { + auto *res = op->getResult(i); + if (res->use_empty()) // ignore dead uses. + continue; + + auto cst = cstBuilder.create( + op->getLoc(), resultConstants[i], res->getType()); + res->replaceAllUsesWith(cst); + } + + assert(op->hasNoSideEffect() && "Constant folded op with side effects?"); + + // FIXME: Generalize to support CFG statements as well. + cast(op)->eraseFromBlock(); + continue; + } + + // If this is an associative binary operation with a constant on the LHS, + // move it to the right side. + if (operandConstants.size() == 2 && operandConstants[0] && + !operandConstants[1]) { + auto *newLHS = op->getOperand(1); + op->setOperand(1, op->getOperand(0)); + op->setOperand(0, newLHS); + } // Check to see if we have any patterns that match this node. auto match = matcher.findMatch(op); @@ -131,6 +218,8 @@ void Canonicalizer::simplifyFunction(std::vector &worklist, // TODO: Need to be a bit trickier to make sure new instructions get into // the worklist. + // TODO: Need to be careful to remove instructions from the worklist when + // they are eliminated by the replace method. match.first->rewrite(op, std::move(match.second), builder); } } diff --git a/mlir/test/Transforms/canonicalize.mlir b/mlir/test/Transforms/canonicalize.mlir index 5f309151b8db..c277096fc40e 100644 --- a/mlir/test/Transforms/canonicalize.mlir +++ b/mlir/test/Transforms/canonicalize.mlir @@ -1,10 +1,31 @@ // RUN: mlir-opt %s -canonicalize | FileCheck %s // CHECK-LABEL: @test_subi_zero -mlfunc @test_subi_zero(%x: i32) -> i32 { +mlfunc @test_subi_zero(%arg0: i32) -> i32 { // CHECK: %c0_i32 = constant 0 : i32 // CHECK-NEXT: return %c0 - %y = subi %x, %x : i32 + %y = subi %arg0, %arg0 : i32 + return %y: i32 +} + +// CHECK-LABEL: mlfunc @dim +mlfunc @dim(%arg0 : tensor<8x4xf32>) -> index { + + // CHECK: %c4 = constant 4 : index + %0 = dim %arg0, 1 : tensor<8x4xf32> + + // CHECK-NEXT: return %c4 + return %0 : index +} + +// CHECK-LABEL: @test_associative +mlfunc @test_associative(%arg0: i32) -> i32 { + // CHECK: %c42_i32 = constant 42 : i32 + // CHECK-NEXT: %0 = addi %arg0, %c42_i32 : i32 + // CHECK-NEXT: return %0 + + %c42_i32 = constant 42 : i32 + %y = addi %c42_i32, %arg0 : i32 return %y: i32 }