forked from OSchip/llvm-project
Introduce a new operation hook point for implementing simple local
canonicalizations of operations. The ultimate important user of this is going to be a funcBuilder->foldOrCreate<YourOp>(...) API, but for now it is just a more convenient way to write certain classes of canonicalizations (see the change in StandardOps.cpp). NFC. PiperOrigin-RevId: 230770021
This commit is contained in:
parent
451869f394
commit
934b6d125f
|
@ -329,6 +329,9 @@ public:
|
|||
bool constantFold(ArrayRef<Attribute> operands,
|
||||
SmallVectorImpl<Attribute> &results) const;
|
||||
|
||||
/// Attempt to fold this operation using the Op's registered foldHook.
|
||||
bool fold(SmallVectorImpl<Value *> &results);
|
||||
|
||||
//===--------------------------------------------------------------------===//
|
||||
// Conversions to declared operations like DimOp
|
||||
//===--------------------------------------------------------------------===//
|
||||
|
|
|
@ -214,15 +214,16 @@ private:
|
|||
OperationInst *state;
|
||||
};
|
||||
|
||||
/// This template defines the constantFoldHook as used by AbstractOperation.
|
||||
/// The default implementation uses a general constantFold method that can be
|
||||
/// defined on custom ops which can return multiple results.
|
||||
/// This template defines the constantFoldHook and foldHook as used by
|
||||
/// AbstractOperation.
|
||||
///
|
||||
/// The default implementation uses a general constantFold/fold method that can
|
||||
/// be defined on custom ops which can return multiple results.
|
||||
template <typename ConcreteType, bool isSingleResult, typename = void>
|
||||
class ConstFoldingHook {
|
||||
class FoldingHook {
|
||||
public:
|
||||
/// This hook implements a constant folder for this operation. It returns
|
||||
/// true if folding failed, or returns false and fills in `results` on
|
||||
/// success.
|
||||
/// This is an implementation detail of the constant folder hook for
|
||||
/// AbstractOperation.
|
||||
static bool constantFoldHook(const OperationInst *op,
|
||||
ArrayRef<Attribute> operands,
|
||||
SmallVectorImpl<Attribute> &results) {
|
||||
|
@ -244,18 +245,47 @@ public:
|
|||
MLIRContext *context) const {
|
||||
return true;
|
||||
}
|
||||
|
||||
/// This is an implementation detail of the folder hook for AbstractOperation.
|
||||
static bool foldHook(OperationInst *op, SmallVectorImpl<Value *> &results) {
|
||||
return op->cast<ConcreteType>()->fold(results);
|
||||
}
|
||||
|
||||
/// This hook implements a generalized folder for this operation. Operations
|
||||
/// can implement this to provide simplifications rules that are applied by
|
||||
/// the FuncBuilder::foldOrCreate API and the canonicalization pass.
|
||||
///
|
||||
/// This is an intentionally limited interface - implementations of this hook
|
||||
/// can only perform the following changes to the operation:
|
||||
///
|
||||
/// 1. They can leave the operation alone and without changing the IR, and
|
||||
/// return true.
|
||||
/// 2. They can mutate the operation in place, without changing anything else
|
||||
/// in the IR. In this case, return false.
|
||||
/// 3. They can return a list of existing values that can be used instead of
|
||||
/// the operation. In this case, fill in the results list and return
|
||||
/// false. The caller will remove the operation and use those results
|
||||
/// instead.
|
||||
///
|
||||
/// This allows expression of some simple in-place canonicalizations (e.g.
|
||||
/// "x+0 -> x", "min(x,y,x,z) -> min(x,y,z)", "x+y-x -> y", etc), but does
|
||||
/// not allow for canonicalizations that need to introduce new operations, not
|
||||
/// even constants (e.g. "x-x -> 0" cannot be expressed).
|
||||
///
|
||||
/// If not overridden, this fallback implementation always fails to fold.
|
||||
///
|
||||
bool fold(SmallVectorImpl<Value *> &results) { return true; }
|
||||
};
|
||||
|
||||
/// This template specialization defines the constantFoldHook as used by
|
||||
/// AbstractOperation for single-result operations. This gives the hook a nicer
|
||||
/// signature that is easier to implement.
|
||||
/// This template specialization defines the constantFoldHook and foldHook as
|
||||
/// used by AbstractOperation for single-result operations. This gives the hook
|
||||
/// a nicer signature that is easier to implement.
|
||||
template <typename ConcreteType, bool isSingleResult>
|
||||
class ConstFoldingHook<ConcreteType, isSingleResult,
|
||||
typename std::enable_if<isSingleResult>::type> {
|
||||
class FoldingHook<ConcreteType, isSingleResult,
|
||||
typename std::enable_if<isSingleResult>::type> {
|
||||
public:
|
||||
/// This hook implements a constant folder for this operation. It returns
|
||||
/// true if folding failed, or returns false and fills in `results` on
|
||||
/// success.
|
||||
/// This is an implementation detail of the constant folder hook for
|
||||
/// AbstractOperation.
|
||||
static bool constantFoldHook(const OperationInst *op,
|
||||
ArrayRef<Attribute> operands,
|
||||
SmallVectorImpl<Attribute> &results) {
|
||||
|
@ -267,6 +297,54 @@ public:
|
|||
results.push_back(result);
|
||||
return false;
|
||||
}
|
||||
|
||||
/// Op implementations can implement this hook. It should attempt to constant
|
||||
/// fold this operation with the specified constant operand values - the
|
||||
/// elements in "operands" will correspond directly to the operands of the
|
||||
/// operation, but may be null if non-constant. If constant folding is
|
||||
/// successful, this returns a non-null attribute, otherwise it returns null
|
||||
/// on failure.
|
||||
///
|
||||
/// If not overridden, this fallback implementation always fails to fold.
|
||||
///
|
||||
Attribute constantFold(ArrayRef<Attribute> operands,
|
||||
MLIRContext *context) const {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
/// This is an implementation detail of the folder hook for AbstractOperation.
|
||||
static bool foldHook(OperationInst *op, SmallVectorImpl<Value *> &results) {
|
||||
auto *result = op->cast<ConcreteType>()->fold();
|
||||
if (!result)
|
||||
return true;
|
||||
if (result != op->getResult(0))
|
||||
results.push_back(result);
|
||||
return false;
|
||||
}
|
||||
|
||||
/// This hook implements a generalized folder for this operation. Operations
|
||||
/// can implement this to provide simplifications rules that are applied by
|
||||
/// the FuncBuilder::foldOrCreate API and the canonicalization pass.
|
||||
///
|
||||
/// This is an intentionally limited interface - implementations of this hook
|
||||
/// can only perform the following changes to the operation:
|
||||
///
|
||||
/// 1. They can leave the operation alone and without changing the IR, and
|
||||
/// return nullptr.
|
||||
/// 2. They can mutate the operation in place, without changing anything else
|
||||
/// in the IR. In this case, return the operation itself.
|
||||
/// 3. They can return an existing SSA value that can be used instead of
|
||||
/// the operation. In this case, return that value. The caller will
|
||||
/// remove the operation and use that result instead.
|
||||
///
|
||||
/// This allows expression of some simple in-place canonicalizations (e.g.
|
||||
/// "x+0 -> x", "min(x,y,x,z) -> min(x,y,z)", "x+y-x -> y", etc), but does
|
||||
/// not allow for canonicalizations that need to introduce new operations, not
|
||||
/// even constants (e.g. "x-x -> 0" cannot be expressed).
|
||||
///
|
||||
/// If not overridden, this fallback implementation always fails to fold.
|
||||
///
|
||||
Value *fold() { return nullptr; }
|
||||
};
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -521,20 +599,6 @@ public:
|
|||
static bool verifyTrait(const OperationInst *op) {
|
||||
return impl::verifyOneResult(op);
|
||||
}
|
||||
|
||||
/// Op implementations can implement this hook. It should attempt to constant
|
||||
/// fold this operation with the specified constant operand values - the
|
||||
/// elements in "operands" will correspond directly to the operands of the
|
||||
/// operation, but may be null if non-constant. If constant folding is
|
||||
/// successful, this returns a non-null attribute, otherwise it returns null
|
||||
/// on failure.
|
||||
///
|
||||
/// If not overridden, this fallback implementation always fails to fold.
|
||||
///
|
||||
Attribute constantFold(ArrayRef<Attribute> operands,
|
||||
MLIRContext *context) const {
|
||||
return nullptr;
|
||||
}
|
||||
};
|
||||
|
||||
/// This class provides the API for ops that are known to have a specified
|
||||
|
@ -783,7 +847,7 @@ public:
|
|||
template <typename ConcreteType, template <typename T> class... Traits>
|
||||
class Op : public OpState,
|
||||
public Traits<ConcreteType>...,
|
||||
public ConstFoldingHook<
|
||||
public FoldingHook<
|
||||
ConcreteType,
|
||||
typelist_contains<OpTrait::OneResult<ConcreteType>, OpState,
|
||||
Traits<ConcreteType>...>::value> {
|
||||
|
|
|
@ -97,6 +97,28 @@ public:
|
|||
ArrayRef<Attribute> operands,
|
||||
SmallVectorImpl<Attribute> &results);
|
||||
|
||||
/// This hook implements a generalized folder for this operation. Operations
|
||||
/// can implement this to provide simplifications rules that are applied by
|
||||
/// the FuncBuilder::foldOrCreate API and the canonicalization pass.
|
||||
///
|
||||
/// This is an intentionally limited interface - implementations of this hook
|
||||
/// can only perform the following changes to the operation:
|
||||
///
|
||||
/// 1. They can leave the operation alone and without changing the IR, and
|
||||
/// return true.
|
||||
/// 2. They can mutate the operation in place, without changing anything else
|
||||
/// in the IR. In this case, return false.
|
||||
/// 3. They can return a list of existing values that can be used instead of
|
||||
/// the operation. In this case, fill in the results list and return
|
||||
/// false. The caller will remove the operation and use those results
|
||||
/// instead.
|
||||
///
|
||||
/// This allows expression of some simple in-place canonicalizations (e.g.
|
||||
/// "x+0 -> x", "min(x,y,x,z) -> min(x,y,z)", "x+y-x -> y", etc), but does
|
||||
/// not allow for canonicalizations that need to introduce new operations, not
|
||||
/// even constants (e.g. "x-x -> 0" cannot be expressed).
|
||||
bool (&foldHook)(OperationInst *op, SmallVectorImpl<Value *> &results);
|
||||
|
||||
/// This hook returns any canonicalization pattern rewrites that the operation
|
||||
/// supports, for use by the canonicalization pass.
|
||||
void (&getCanonicalizationPatterns)(OwningRewritePatternList &results,
|
||||
|
@ -118,7 +140,7 @@ public:
|
|||
return AbstractOperation(
|
||||
T::getOperationName(), dialect, T::getOperationProperties(),
|
||||
T::isClassFor, T::parseAssembly, T::printAssembly, T::verifyInvariants,
|
||||
T::constantFoldHook, T::getCanonicalizationPatterns);
|
||||
T::constantFoldHook, T::foldHook, T::getCanonicalizationPatterns);
|
||||
}
|
||||
|
||||
private:
|
||||
|
@ -131,11 +153,13 @@ private:
|
|||
bool (&constantFoldHook)(const OperationInst *op,
|
||||
ArrayRef<Attribute> operands,
|
||||
SmallVectorImpl<Attribute> &results),
|
||||
bool (&foldHook)(OperationInst *op, SmallVectorImpl<Value *> &results),
|
||||
void (&getCanonicalizationPatterns)(OwningRewritePatternList &results,
|
||||
MLIRContext *context))
|
||||
: name(name), dialect(dialect), isClassFor(isClassFor),
|
||||
parseAssembly(parseAssembly), printAssembly(printAssembly),
|
||||
verifyInvariants(verifyInvariants), constantFoldHook(constantFoldHook),
|
||||
foldHook(foldHook),
|
||||
getCanonicalizationPatterns(getCanonicalizationPatterns),
|
||||
opProperties(opProperties) {}
|
||||
|
||||
|
|
|
@ -404,6 +404,9 @@ class Op<string mnemonic, list<OpProperty> props = []> {
|
|||
// Whether this op has a constant folder.
|
||||
bit hasConstantFolder = 0b0;
|
||||
|
||||
// Whether this op has a folder.
|
||||
bit hasFolder = 0b0;
|
||||
|
||||
// Op properties.
|
||||
list<OpProperty> properties = props;
|
||||
}
|
||||
|
|
|
@ -89,7 +89,7 @@ def AddFOp : FloatArithmeticOp<"addf"> {
|
|||
|
||||
def AddIOp : IntArithmeticOp<"addi", [Commutative]> {
|
||||
let summary = "integer addition operation";
|
||||
let hasCanonicalizer = 0b1;
|
||||
let hasFolder = 1;
|
||||
let hasConstantFolder = 0b1;
|
||||
}
|
||||
|
||||
|
|
|
@ -532,13 +532,11 @@ bool OperationInst::constantFold(ArrayRef<Attribute> operands,
|
|||
if (auto *abstractOp = getAbstractOperation()) {
|
||||
// If we have a registered operation definition matching this one, use it to
|
||||
// try to constant fold the operation.
|
||||
if (!abstractOp->constantFoldHook(llvm::cast<OperationInst>(this), operands,
|
||||
results))
|
||||
if (!abstractOp->constantFoldHook(this, operands, results))
|
||||
return false;
|
||||
|
||||
// Otherwise, fall back on the dialect hook to handle it.
|
||||
return abstractOp->dialect.constantFoldHook(llvm::cast<OperationInst>(this),
|
||||
operands, results);
|
||||
return abstractOp->dialect.constantFoldHook(this, operands, results);
|
||||
}
|
||||
|
||||
// If this operation hasn't been registered or doesn't have abstract
|
||||
|
@ -546,13 +544,23 @@ bool OperationInst::constantFold(ArrayRef<Attribute> operands,
|
|||
auto opName = getName().getStringRef();
|
||||
auto dialectPrefix = opName.split('.').first;
|
||||
if (auto *dialect = getContext()->getRegisteredDialect(dialectPrefix)) {
|
||||
return dialect->constantFoldHook(llvm::cast<OperationInst>(this), operands,
|
||||
results);
|
||||
return dialect->constantFoldHook(this, operands, results);
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
/// Attempt to fold this operation using the Op's registered foldHook.
|
||||
bool OperationInst::fold(SmallVectorImpl<Value *> &results) {
|
||||
if (auto *abstractOp = getAbstractOperation()) {
|
||||
// If we have a registered operation definition matching this one, use it to
|
||||
// try to constant fold the operation.
|
||||
if (!abstractOp->foldHook(this, results))
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
/// Emit an error with the op name prefixed, like "'dim' op " which is
|
||||
/// convenient for verifiers.
|
||||
bool OperationInst::emitOpError(const Twine &message) const {
|
||||
|
|
|
@ -128,30 +128,12 @@ Attribute AddIOp::constantFold(ArrayRef<Attribute> operands,
|
|||
[](APInt a, APInt b) { return a + b; });
|
||||
}
|
||||
|
||||
namespace {
|
||||
/// addi(x, 0) -> x
|
||||
///
|
||||
struct SimplifyAddX0 : public RewritePattern {
|
||||
SimplifyAddX0(MLIRContext *context)
|
||||
: RewritePattern(AddIOp::getOperationName(), 1, context) {}
|
||||
Value *AddIOp::fold() {
|
||||
/// addi(x, 0) -> x
|
||||
if (matchPattern(getOperand(1), m_Zero()))
|
||||
return getOperand(0);
|
||||
|
||||
PatternMatchResult match(OperationInst *op) const override {
|
||||
auto addi = op->cast<AddIOp>();
|
||||
|
||||
if (matchPattern(addi->getOperand(1), m_Zero()))
|
||||
return matchSuccess();
|
||||
|
||||
return matchFailure();
|
||||
}
|
||||
void rewrite(OperationInst *op, PatternRewriter &rewriter) const override {
|
||||
rewriter.replaceOp(op, op->getOperand(0));
|
||||
}
|
||||
};
|
||||
} // end anonymous namespace.
|
||||
|
||||
void AddIOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
|
||||
MLIRContext *context) {
|
||||
results.push_back(std::make_unique<SimplifyAddX0>(context));
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -130,6 +130,7 @@ private:
|
|||
void GreedyPatternRewriteDriver::simplifyFunction() {
|
||||
// These are scratch vectors used in the constant folding loop below.
|
||||
SmallVector<Attribute, 8> operandConstants, resultConstants;
|
||||
SmallVector<Value *, 8> originalOperands, resultValues;
|
||||
|
||||
while (!worklist.empty()) {
|
||||
auto *op = popFromWorklist();
|
||||
|
@ -195,6 +196,14 @@ void GreedyPatternRewriteDriver::simplifyFunction() {
|
|||
operandConstants.push_back(operandCst);
|
||||
}
|
||||
|
||||
// If this is a commutative binary operation with a constant on the left
|
||||
// side move it to the right side.
|
||||
if (operandConstants.size() == 2 && operandConstants[0] &&
|
||||
!operandConstants[1] && op->isCommutative()) {
|
||||
std::swap(op->getInstOperand(0), op->getInstOperand(1));
|
||||
std::swap(operandConstants[0], operandConstants[1]);
|
||||
}
|
||||
|
||||
// If constant folding was successful, create the result constants, RAUW the
|
||||
// operation and remove it.
|
||||
resultConstants.clear();
|
||||
|
@ -233,13 +242,41 @@ void GreedyPatternRewriteDriver::simplifyFunction() {
|
|||
continue;
|
||||
}
|
||||
|
||||
// If this is a commutative binary operation with a constant on the left
|
||||
// side move it to the right side.
|
||||
if (operandConstants.size() == 2 && operandConstants[0] &&
|
||||
!operandConstants[1] && op->isCommutative()) {
|
||||
auto *newLHS = op->getOperand(1);
|
||||
op->setOperand(1, op->getOperand(0));
|
||||
op->setOperand(0, newLHS);
|
||||
// Otherwise see if we can use the generic folder API to simplify the
|
||||
// operation.
|
||||
originalOperands.assign(op->operand_begin(), op->operand_end());
|
||||
resultValues.clear();
|
||||
if (!op->fold(resultValues)) {
|
||||
// If the result was an in-place simplification (e.g. max(x,x,y) ->
|
||||
// max(x,y)) then add the original operands to the worklist so we can make
|
||||
// sure to revisit them.
|
||||
if (resultValues.empty()) {
|
||||
// TODO: Walk the original operand list dropping them as we go. If any
|
||||
// of them drop to zero uses, then add them to the worklist to allow
|
||||
// them to be deleted as dead.
|
||||
} else {
|
||||
// Otherwise, the operation is simplified away completely.
|
||||
assert(resultValues.size() == op->getNumResults());
|
||||
|
||||
// Add all the users of the operation to the worklist so we make sure to
|
||||
// revisit them.
|
||||
//
|
||||
// TODO: Add a result->getUsers() iterator.
|
||||
for (unsigned i = 0, e = resultValues.size(); i != e; ++i) {
|
||||
auto *res = op->getResult(i);
|
||||
if (res->use_empty()) // ignore dead uses.
|
||||
continue;
|
||||
|
||||
for (auto &operand : op->getResult(i)->getUses()) {
|
||||
if (auto *op = dyn_cast<OperationInst>(operand.getOwner()))
|
||||
addToWorklist(op);
|
||||
}
|
||||
res->replaceAllUsesWith(resultValues[i]);
|
||||
}
|
||||
}
|
||||
|
||||
op->erase();
|
||||
continue;
|
||||
}
|
||||
|
||||
// Check to see if we have any patterns that match this node.
|
||||
|
|
|
@ -104,8 +104,8 @@ public:
|
|||
// Emit method declaration for the getCanonicalizationPatterns() interface.
|
||||
void emitCanonicalizationPatterns();
|
||||
|
||||
// Emit the constant folder method for the operation.
|
||||
void emitConstantFolder();
|
||||
// Emit the folder methods for the operation.
|
||||
void emitFolders();
|
||||
|
||||
// Emit the parser for the operation.
|
||||
void emitParser();
|
||||
|
@ -165,7 +165,7 @@ void OpEmitter::emit(const Record &def, raw_ostream &os) {
|
|||
emitter.emitVerifier();
|
||||
emitter.emitAttrGetters();
|
||||
emitter.emitCanonicalizationPatterns();
|
||||
emitter.emitConstantFolder();
|
||||
emitter.emitFolders();
|
||||
|
||||
os << "private:\n friend class ::mlir::OperationInst;\n"
|
||||
<< " explicit " << emitter.op.cppClassName()
|
||||
|
@ -333,16 +333,25 @@ void OpEmitter::emitCanonicalizationPatterns() {
|
|||
<< "OwningRewritePatternList &results, MLIRContext* context);\n";
|
||||
}
|
||||
|
||||
void OpEmitter::emitConstantFolder() {
|
||||
if (!def.getValueAsBit("hasConstantFolder"))
|
||||
return;
|
||||
if (def.getValueAsListOfDefs("returnTypes").size() == 1) {
|
||||
os << " Attribute constantFold(ArrayRef<Attribute> operands,\n"
|
||||
" MLIRContext *context) const;\n";
|
||||
} else {
|
||||
os << " bool constantFold(ArrayRef<Attribute> operands,\n"
|
||||
<< " SmallVectorImpl<Attribute> &results,\n"
|
||||
<< " MLIRContext *context) const;\n";
|
||||
void OpEmitter::emitFolders() {
|
||||
bool hasSingleResult = def.getValueAsListOfDefs("returnTypes").size() == 1;
|
||||
if (def.getValueAsBit("hasConstantFolder")) {
|
||||
if (hasSingleResult) {
|
||||
os << " Attribute constantFold(ArrayRef<Attribute> operands,\n"
|
||||
" MLIRContext *context) const;\n";
|
||||
} else {
|
||||
os << " bool constantFold(ArrayRef<Attribute> operands,\n"
|
||||
<< " SmallVectorImpl<Attribute> &results,\n"
|
||||
<< " MLIRContext *context) const;\n";
|
||||
}
|
||||
}
|
||||
|
||||
if (def.getValueAsBit("hasFolder")) {
|
||||
if (hasSingleResult) {
|
||||
os << " Value *fold();\n";
|
||||
} else {
|
||||
os << " bool fold(SmallVectorImpl<Value *> &results);\n";
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue