Introduce a new operation hook point for implementing simple local

canonicalizations of operations.  The ultimate important user of this is
going to be a funcBuilder->foldOrCreate<YourOp>(...) API, but for now it
is just a more convenient way to write certain classes of canonicalizations
(see the change in StandardOps.cpp).

NFC.

PiperOrigin-RevId: 230770021
This commit is contained in:
Chris Lattner 2019-01-24 12:34:00 -08:00 committed by jpienaar
parent 451869f394
commit 934b6d125f
9 changed files with 211 additions and 81 deletions

View File

@ -329,6 +329,9 @@ public:
bool constantFold(ArrayRef<Attribute> operands,
SmallVectorImpl<Attribute> &results) const;
/// Attempt to fold this operation using the Op's registered foldHook.
bool fold(SmallVectorImpl<Value *> &results);
//===--------------------------------------------------------------------===//
// Conversions to declared operations like DimOp
//===--------------------------------------------------------------------===//

View File

@ -214,15 +214,16 @@ private:
OperationInst *state;
};
/// This template defines the constantFoldHook as used by AbstractOperation.
/// The default implementation uses a general constantFold method that can be
/// defined on custom ops which can return multiple results.
/// This template defines the constantFoldHook and foldHook as used by
/// AbstractOperation.
///
/// The default implementation uses a general constantFold/fold method that can
/// be defined on custom ops which can return multiple results.
template <typename ConcreteType, bool isSingleResult, typename = void>
class ConstFoldingHook {
class FoldingHook {
public:
/// This hook implements a constant folder for this operation. It returns
/// true if folding failed, or returns false and fills in `results` on
/// success.
/// This is an implementation detail of the constant folder hook for
/// AbstractOperation.
static bool constantFoldHook(const OperationInst *op,
ArrayRef<Attribute> operands,
SmallVectorImpl<Attribute> &results) {
@ -244,18 +245,47 @@ public:
MLIRContext *context) const {
return true;
}
/// This is an implementation detail of the folder hook for AbstractOperation.
static bool foldHook(OperationInst *op, SmallVectorImpl<Value *> &results) {
return op->cast<ConcreteType>()->fold(results);
}
/// This hook implements a generalized folder for this operation. Operations
/// can implement this to provide simplifications rules that are applied by
/// the FuncBuilder::foldOrCreate API and the canonicalization pass.
///
/// This is an intentionally limited interface - implementations of this hook
/// can only perform the following changes to the operation:
///
/// 1. They can leave the operation alone and without changing the IR, and
/// return true.
/// 2. They can mutate the operation in place, without changing anything else
/// in the IR. In this case, return false.
/// 3. They can return a list of existing values that can be used instead of
/// the operation. In this case, fill in the results list and return
/// false. The caller will remove the operation and use those results
/// instead.
///
/// This allows expression of some simple in-place canonicalizations (e.g.
/// "x+0 -> x", "min(x,y,x,z) -> min(x,y,z)", "x+y-x -> y", etc), but does
/// not allow for canonicalizations that need to introduce new operations, not
/// even constants (e.g. "x-x -> 0" cannot be expressed).
///
/// If not overridden, this fallback implementation always fails to fold.
///
bool fold(SmallVectorImpl<Value *> &results) { return true; }
};
/// This template specialization defines the constantFoldHook as used by
/// AbstractOperation for single-result operations. This gives the hook a nicer
/// signature that is easier to implement.
/// This template specialization defines the constantFoldHook and foldHook as
/// used by AbstractOperation for single-result operations. This gives the hook
/// a nicer signature that is easier to implement.
template <typename ConcreteType, bool isSingleResult>
class ConstFoldingHook<ConcreteType, isSingleResult,
typename std::enable_if<isSingleResult>::type> {
class FoldingHook<ConcreteType, isSingleResult,
typename std::enable_if<isSingleResult>::type> {
public:
/// This hook implements a constant folder for this operation. It returns
/// true if folding failed, or returns false and fills in `results` on
/// success.
/// This is an implementation detail of the constant folder hook for
/// AbstractOperation.
static bool constantFoldHook(const OperationInst *op,
ArrayRef<Attribute> operands,
SmallVectorImpl<Attribute> &results) {
@ -267,6 +297,54 @@ public:
results.push_back(result);
return false;
}
/// Op implementations can implement this hook. It should attempt to constant
/// fold this operation with the specified constant operand values - the
/// elements in "operands" will correspond directly to the operands of the
/// operation, but may be null if non-constant. If constant folding is
/// successful, this returns a non-null attribute, otherwise it returns null
/// on failure.
///
/// If not overridden, this fallback implementation always fails to fold.
///
Attribute constantFold(ArrayRef<Attribute> operands,
MLIRContext *context) const {
return nullptr;
}
/// This is an implementation detail of the folder hook for AbstractOperation.
static bool foldHook(OperationInst *op, SmallVectorImpl<Value *> &results) {
auto *result = op->cast<ConcreteType>()->fold();
if (!result)
return true;
if (result != op->getResult(0))
results.push_back(result);
return false;
}
/// This hook implements a generalized folder for this operation. Operations
/// can implement this to provide simplifications rules that are applied by
/// the FuncBuilder::foldOrCreate API and the canonicalization pass.
///
/// This is an intentionally limited interface - implementations of this hook
/// can only perform the following changes to the operation:
///
/// 1. They can leave the operation alone and without changing the IR, and
/// return nullptr.
/// 2. They can mutate the operation in place, without changing anything else
/// in the IR. In this case, return the operation itself.
/// 3. They can return an existing SSA value that can be used instead of
/// the operation. In this case, return that value. The caller will
/// remove the operation and use that result instead.
///
/// This allows expression of some simple in-place canonicalizations (e.g.
/// "x+0 -> x", "min(x,y,x,z) -> min(x,y,z)", "x+y-x -> y", etc), but does
/// not allow for canonicalizations that need to introduce new operations, not
/// even constants (e.g. "x-x -> 0" cannot be expressed).
///
/// If not overridden, this fallback implementation always fails to fold.
///
Value *fold() { return nullptr; }
};
//===----------------------------------------------------------------------===//
@ -521,20 +599,6 @@ public:
static bool verifyTrait(const OperationInst *op) {
return impl::verifyOneResult(op);
}
/// Op implementations can implement this hook. It should attempt to constant
/// fold this operation with the specified constant operand values - the
/// elements in "operands" will correspond directly to the operands of the
/// operation, but may be null if non-constant. If constant folding is
/// successful, this returns a non-null attribute, otherwise it returns null
/// on failure.
///
/// If not overridden, this fallback implementation always fails to fold.
///
Attribute constantFold(ArrayRef<Attribute> operands,
MLIRContext *context) const {
return nullptr;
}
};
/// This class provides the API for ops that are known to have a specified
@ -783,7 +847,7 @@ public:
template <typename ConcreteType, template <typename T> class... Traits>
class Op : public OpState,
public Traits<ConcreteType>...,
public ConstFoldingHook<
public FoldingHook<
ConcreteType,
typelist_contains<OpTrait::OneResult<ConcreteType>, OpState,
Traits<ConcreteType>...>::value> {

View File

@ -97,6 +97,28 @@ public:
ArrayRef<Attribute> operands,
SmallVectorImpl<Attribute> &results);
/// This hook implements a generalized folder for this operation. Operations
/// can implement this to provide simplifications rules that are applied by
/// the FuncBuilder::foldOrCreate API and the canonicalization pass.
///
/// This is an intentionally limited interface - implementations of this hook
/// can only perform the following changes to the operation:
///
/// 1. They can leave the operation alone and without changing the IR, and
/// return true.
/// 2. They can mutate the operation in place, without changing anything else
/// in the IR. In this case, return false.
/// 3. They can return a list of existing values that can be used instead of
/// the operation. In this case, fill in the results list and return
/// false. The caller will remove the operation and use those results
/// instead.
///
/// This allows expression of some simple in-place canonicalizations (e.g.
/// "x+0 -> x", "min(x,y,x,z) -> min(x,y,z)", "x+y-x -> y", etc), but does
/// not allow for canonicalizations that need to introduce new operations, not
/// even constants (e.g. "x-x -> 0" cannot be expressed).
bool (&foldHook)(OperationInst *op, SmallVectorImpl<Value *> &results);
/// This hook returns any canonicalization pattern rewrites that the operation
/// supports, for use by the canonicalization pass.
void (&getCanonicalizationPatterns)(OwningRewritePatternList &results,
@ -118,7 +140,7 @@ public:
return AbstractOperation(
T::getOperationName(), dialect, T::getOperationProperties(),
T::isClassFor, T::parseAssembly, T::printAssembly, T::verifyInvariants,
T::constantFoldHook, T::getCanonicalizationPatterns);
T::constantFoldHook, T::foldHook, T::getCanonicalizationPatterns);
}
private:
@ -131,11 +153,13 @@ private:
bool (&constantFoldHook)(const OperationInst *op,
ArrayRef<Attribute> operands,
SmallVectorImpl<Attribute> &results),
bool (&foldHook)(OperationInst *op, SmallVectorImpl<Value *> &results),
void (&getCanonicalizationPatterns)(OwningRewritePatternList &results,
MLIRContext *context))
: name(name), dialect(dialect), isClassFor(isClassFor),
parseAssembly(parseAssembly), printAssembly(printAssembly),
verifyInvariants(verifyInvariants), constantFoldHook(constantFoldHook),
foldHook(foldHook),
getCanonicalizationPatterns(getCanonicalizationPatterns),
opProperties(opProperties) {}

View File

@ -404,6 +404,9 @@ class Op<string mnemonic, list<OpProperty> props = []> {
// Whether this op has a constant folder.
bit hasConstantFolder = 0b0;
// Whether this op has a folder.
bit hasFolder = 0b0;
// Op properties.
list<OpProperty> properties = props;
}

View File

@ -89,7 +89,7 @@ def AddFOp : FloatArithmeticOp<"addf"> {
def AddIOp : IntArithmeticOp<"addi", [Commutative]> {
let summary = "integer addition operation";
let hasCanonicalizer = 0b1;
let hasFolder = 1;
let hasConstantFolder = 0b1;
}

View File

@ -532,13 +532,11 @@ bool OperationInst::constantFold(ArrayRef<Attribute> operands,
if (auto *abstractOp = getAbstractOperation()) {
// If we have a registered operation definition matching this one, use it to
// try to constant fold the operation.
if (!abstractOp->constantFoldHook(llvm::cast<OperationInst>(this), operands,
results))
if (!abstractOp->constantFoldHook(this, operands, results))
return false;
// Otherwise, fall back on the dialect hook to handle it.
return abstractOp->dialect.constantFoldHook(llvm::cast<OperationInst>(this),
operands, results);
return abstractOp->dialect.constantFoldHook(this, operands, results);
}
// If this operation hasn't been registered or doesn't have abstract
@ -546,13 +544,23 @@ bool OperationInst::constantFold(ArrayRef<Attribute> operands,
auto opName = getName().getStringRef();
auto dialectPrefix = opName.split('.').first;
if (auto *dialect = getContext()->getRegisteredDialect(dialectPrefix)) {
return dialect->constantFoldHook(llvm::cast<OperationInst>(this), operands,
results);
return dialect->constantFoldHook(this, operands, results);
}
return true;
}
/// Attempt to fold this operation using the Op's registered foldHook.
bool OperationInst::fold(SmallVectorImpl<Value *> &results) {
if (auto *abstractOp = getAbstractOperation()) {
// If we have a registered operation definition matching this one, use it to
// try to constant fold the operation.
if (!abstractOp->foldHook(this, results))
return false;
}
return true;
}
/// Emit an error with the op name prefixed, like "'dim' op " which is
/// convenient for verifiers.
bool OperationInst::emitOpError(const Twine &message) const {

View File

@ -128,30 +128,12 @@ Attribute AddIOp::constantFold(ArrayRef<Attribute> operands,
[](APInt a, APInt b) { return a + b; });
}
namespace {
/// addi(x, 0) -> x
///
struct SimplifyAddX0 : public RewritePattern {
SimplifyAddX0(MLIRContext *context)
: RewritePattern(AddIOp::getOperationName(), 1, context) {}
Value *AddIOp::fold() {
/// addi(x, 0) -> x
if (matchPattern(getOperand(1), m_Zero()))
return getOperand(0);
PatternMatchResult match(OperationInst *op) const override {
auto addi = op->cast<AddIOp>();
if (matchPattern(addi->getOperand(1), m_Zero()))
return matchSuccess();
return matchFailure();
}
void rewrite(OperationInst *op, PatternRewriter &rewriter) const override {
rewriter.replaceOp(op, op->getOperand(0));
}
};
} // end anonymous namespace.
void AddIOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
MLIRContext *context) {
results.push_back(std::make_unique<SimplifyAddX0>(context));
return nullptr;
}
//===----------------------------------------------------------------------===//

View File

@ -130,6 +130,7 @@ private:
void GreedyPatternRewriteDriver::simplifyFunction() {
// These are scratch vectors used in the constant folding loop below.
SmallVector<Attribute, 8> operandConstants, resultConstants;
SmallVector<Value *, 8> originalOperands, resultValues;
while (!worklist.empty()) {
auto *op = popFromWorklist();
@ -195,6 +196,14 @@ void GreedyPatternRewriteDriver::simplifyFunction() {
operandConstants.push_back(operandCst);
}
// If this is a commutative binary operation with a constant on the left
// side move it to the right side.
if (operandConstants.size() == 2 && operandConstants[0] &&
!operandConstants[1] && op->isCommutative()) {
std::swap(op->getInstOperand(0), op->getInstOperand(1));
std::swap(operandConstants[0], operandConstants[1]);
}
// If constant folding was successful, create the result constants, RAUW the
// operation and remove it.
resultConstants.clear();
@ -233,13 +242,41 @@ void GreedyPatternRewriteDriver::simplifyFunction() {
continue;
}
// If this is a commutative binary operation with a constant on the left
// side move it to the right side.
if (operandConstants.size() == 2 && operandConstants[0] &&
!operandConstants[1] && op->isCommutative()) {
auto *newLHS = op->getOperand(1);
op->setOperand(1, op->getOperand(0));
op->setOperand(0, newLHS);
// Otherwise see if we can use the generic folder API to simplify the
// operation.
originalOperands.assign(op->operand_begin(), op->operand_end());
resultValues.clear();
if (!op->fold(resultValues)) {
// If the result was an in-place simplification (e.g. max(x,x,y) ->
// max(x,y)) then add the original operands to the worklist so we can make
// sure to revisit them.
if (resultValues.empty()) {
// TODO: Walk the original operand list dropping them as we go. If any
// of them drop to zero uses, then add them to the worklist to allow
// them to be deleted as dead.
} else {
// Otherwise, the operation is simplified away completely.
assert(resultValues.size() == op->getNumResults());
// Add all the users of the operation to the worklist so we make sure to
// revisit them.
//
// TODO: Add a result->getUsers() iterator.
for (unsigned i = 0, e = resultValues.size(); i != e; ++i) {
auto *res = op->getResult(i);
if (res->use_empty()) // ignore dead uses.
continue;
for (auto &operand : op->getResult(i)->getUses()) {
if (auto *op = dyn_cast<OperationInst>(operand.getOwner()))
addToWorklist(op);
}
res->replaceAllUsesWith(resultValues[i]);
}
}
op->erase();
continue;
}
// Check to see if we have any patterns that match this node.

View File

@ -104,8 +104,8 @@ public:
// Emit method declaration for the getCanonicalizationPatterns() interface.
void emitCanonicalizationPatterns();
// Emit the constant folder method for the operation.
void emitConstantFolder();
// Emit the folder methods for the operation.
void emitFolders();
// Emit the parser for the operation.
void emitParser();
@ -165,7 +165,7 @@ void OpEmitter::emit(const Record &def, raw_ostream &os) {
emitter.emitVerifier();
emitter.emitAttrGetters();
emitter.emitCanonicalizationPatterns();
emitter.emitConstantFolder();
emitter.emitFolders();
os << "private:\n friend class ::mlir::OperationInst;\n"
<< " explicit " << emitter.op.cppClassName()
@ -333,16 +333,25 @@ void OpEmitter::emitCanonicalizationPatterns() {
<< "OwningRewritePatternList &results, MLIRContext* context);\n";
}
void OpEmitter::emitConstantFolder() {
if (!def.getValueAsBit("hasConstantFolder"))
return;
if (def.getValueAsListOfDefs("returnTypes").size() == 1) {
os << " Attribute constantFold(ArrayRef<Attribute> operands,\n"
" MLIRContext *context) const;\n";
} else {
os << " bool constantFold(ArrayRef<Attribute> operands,\n"
<< " SmallVectorImpl<Attribute> &results,\n"
<< " MLIRContext *context) const;\n";
void OpEmitter::emitFolders() {
bool hasSingleResult = def.getValueAsListOfDefs("returnTypes").size() == 1;
if (def.getValueAsBit("hasConstantFolder")) {
if (hasSingleResult) {
os << " Attribute constantFold(ArrayRef<Attribute> operands,\n"
" MLIRContext *context) const;\n";
} else {
os << " bool constantFold(ArrayRef<Attribute> operands,\n"
<< " SmallVectorImpl<Attribute> &results,\n"
<< " MLIRContext *context) const;\n";
}
}
if (def.getValueAsBit("hasFolder")) {
if (hasSingleResult) {
os << " Value *fold();\n";
} else {
os << " bool fold(SmallVectorImpl<Value *> &results);\n";
}
}
}