forked from OSchip/llvm-project
Add new 'createOrFold' methods to FuncBuilder to immediately try to fold an operation after creating it. This can be used to remove operations that are likely to be trivially folded later. Note, these functions only fold operations if all of the folded results are existing values.
PiperOrigin-RevId: 251674299
This commit is contained in:
parent
9fc00cf840
commit
6f5f5a9178
|
@ -19,7 +19,7 @@
|
|||
#define MLIR_IR_BUILDERS_H
|
||||
|
||||
#include "mlir/IR/Function.h"
|
||||
#include "mlir/IR/Operation.h"
|
||||
#include "mlir/IR/OpDefinition.h"
|
||||
|
||||
namespace mlir {
|
||||
|
||||
|
@ -297,7 +297,7 @@ public:
|
|||
/// Creates an operation given the fields represented as an OperationState.
|
||||
virtual Operation *createOperation(const OperationState &state);
|
||||
|
||||
/// Create operation of specific op type at the current insertion point.
|
||||
/// Create an operation of specific op type at the current insertion point.
|
||||
template <typename OpTy, typename... Args>
|
||||
OpTy create(Location location, Args... args) {
|
||||
OperationState state(getContext(), location, OpTy::getOperationName());
|
||||
|
@ -308,6 +308,40 @@ public:
|
|||
return result;
|
||||
}
|
||||
|
||||
/// Create an operation of specific op type at the current insertion point,
|
||||
/// and immediately try to fold it. This functions populates 'results' with
|
||||
/// the results after folding the operation.
|
||||
template <typename OpTy, typename... Args>
|
||||
void createOrFold(SmallVectorImpl<Value *> &results, Location location,
|
||||
Args &&... args) {
|
||||
auto op = create<OpTy>(location, std::forward<Args>(args)...);
|
||||
tryFold(op.getOperation(), results);
|
||||
}
|
||||
|
||||
/// Overload to create or fold a single result operation.
|
||||
template <typename OpTy, typename... Args>
|
||||
typename std::enable_if<OpTy::template hasTrait<OpTrait::OneResult>(),
|
||||
Value *>::type
|
||||
createOrFold(Location location, Args &&... args) {
|
||||
SmallVector<Value *, 1> results;
|
||||
createOrFold<OpTy>(results, location, std::forward<Args>(args)...);
|
||||
return results.front();
|
||||
}
|
||||
|
||||
/// Overload to create or fold a zero result operation.
|
||||
template <typename OpTy, typename... Args>
|
||||
typename std::enable_if<OpTy::template hasTrait<OpTrait::ZeroResult>(),
|
||||
OpTy>::type
|
||||
createOrFold(Location location, Args &&... args) {
|
||||
auto op = create<OpTy>(location, std::forward<Args>(args)...);
|
||||
SmallVector<Value *, 0> unused;
|
||||
tryFold(op.getOperation(), unused);
|
||||
|
||||
// Folding cannot remove a zero-result operation, so for convenience we
|
||||
// continue to return it.
|
||||
return op;
|
||||
}
|
||||
|
||||
/// Creates a deep copy of the specified operation, remapping any operands
|
||||
/// that use values outside of the operation using the map that is provided
|
||||
/// ( leaving them alone if no entry is present). Replaces references to
|
||||
|
@ -339,6 +373,10 @@ public:
|
|||
}
|
||||
|
||||
private:
|
||||
/// Attempts to fold the given operation and places new results within
|
||||
/// 'results'.
|
||||
void tryFold(Operation *op, SmallVectorImpl<Value *> &results);
|
||||
|
||||
Region *region;
|
||||
Block *block = nullptr;
|
||||
Block::iterator insertPoint;
|
||||
|
|
|
@ -194,7 +194,7 @@ public:
|
|||
|
||||
/// This hook implements a generalized folder for this operation. Operations
|
||||
/// can implement this to provide simplifications rules that are applied by
|
||||
/// the Builder::foldOrCreate API and the canonicalization pass.
|
||||
/// the Builder::createOrFold API and the canonicalization pass.
|
||||
///
|
||||
/// This is an intentionally limited interface - implementations of this hook
|
||||
/// can only perform the following changes to the operation:
|
||||
|
@ -250,7 +250,7 @@ public:
|
|||
|
||||
/// This hook implements a generalized folder for this operation. Operations
|
||||
/// can implement this to provide simplifications rules that are applied by
|
||||
/// the Builder::foldOrCreate API and the canonicalization pass.
|
||||
/// the Builder::createOrFold API and the canonicalization pass.
|
||||
///
|
||||
/// This is an intentionally limited interface - implementations of this hook
|
||||
/// can only perform the following changes to the operation:
|
||||
|
|
|
@ -106,7 +106,7 @@ public:
|
|||
|
||||
/// This hook implements a generalized folder for this operation. Operations
|
||||
/// can implement this to provide simplifications rules that are applied by
|
||||
/// the Builder::foldOrCreate API and the canonicalization pass.
|
||||
/// the Builder::createOrFold API and the canonicalization pass.
|
||||
///
|
||||
/// This is an intentionally limited interface - implementations of this hook
|
||||
/// can only perform the following changes to the operation:
|
||||
|
|
|
@ -362,3 +362,29 @@ Operation *OpBuilder::createOperation(const OperationState &state) {
|
|||
block->getOperations().insert(insertPoint, op);
|
||||
return op;
|
||||
}
|
||||
|
||||
/// Attempts to fold the given operation and places new results within
|
||||
/// 'results'.
|
||||
void OpBuilder::tryFold(Operation *op, SmallVectorImpl<Value *> &results) {
|
||||
results.reserve(op->getNumResults());
|
||||
SmallVector<OpFoldResult, 4> foldResults;
|
||||
|
||||
// Returns if the given fold result corresponds to a valid existing value.
|
||||
auto isValidValue = [](OpFoldResult result) {
|
||||
return result.dyn_cast<Value *>();
|
||||
};
|
||||
|
||||
// Check if the fold failed, or did not result in only existing values.
|
||||
SmallVector<Attribute, 4> constOperands(op->getNumOperands());
|
||||
if (failed(op->fold(constOperands, foldResults)) || foldResults.empty() ||
|
||||
!llvm::all_of(foldResults, isValidValue)) {
|
||||
// Simply return the existing operation results.
|
||||
results.assign(op->result_begin(), op->result_end());
|
||||
return;
|
||||
}
|
||||
|
||||
// Populate the results with the folded results and remove the original op.
|
||||
llvm::transform(foldResults, std::back_inserter(results),
|
||||
[](OpFoldResult result) { return result.get<Value *>(); });
|
||||
op->erase();
|
||||
}
|
||||
|
|
|
@ -93,31 +93,15 @@ SmallVector<Value *, 8> mlir::linalg::getViewSizes(LinalgOp &linalgOp) {
|
|||
return res;
|
||||
}
|
||||
|
||||
// Folding eagerly is necessary to abide by affine.for static step requirement.
|
||||
// We must propagate constants on the steps as aggressively as possible.
|
||||
// Returns nullptr if folding is not trivially feasible.
|
||||
static Value *tryFold(AffineMap map, ArrayRef<Value *> operands,
|
||||
FunctionConstants &state) {
|
||||
assert(map.getNumResults() == 1 && "single result map expected");
|
||||
auto expr = map.getResult(0);
|
||||
if (auto dim = expr.dyn_cast<AffineDimExpr>())
|
||||
return operands[dim.getPosition()];
|
||||
if (auto sym = expr.dyn_cast<AffineSymbolExpr>())
|
||||
return operands[map.getNumDims() + sym.getPosition()];
|
||||
if (auto cst = expr.dyn_cast<AffineConstantExpr>())
|
||||
return state.getOrCreateIndex(cst.getValue());
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
static Value *emitOrFoldComposedAffineApply(OpBuilder *b, Location loc,
|
||||
AffineMap map,
|
||||
ArrayRef<Value *> operandsRef,
|
||||
FunctionConstants &state) {
|
||||
SmallVector<Value *, 4> operands(operandsRef.begin(), operandsRef.end());
|
||||
fullyComposeAffineMapAndOperands(&map, &operands);
|
||||
if (auto *v = tryFold(map, operands, state))
|
||||
return v;
|
||||
return b->create<AffineApplyOp>(loc, map, operands);
|
||||
if (auto cst = map.getResult(0).dyn_cast<AffineConstantExpr>())
|
||||
return state.getOrCreateIndex(cst.getValue());
|
||||
return b->createOrFold<AffineApplyOp>(loc, map, operands);
|
||||
}
|
||||
|
||||
SmallVector<Value *, 4>
|
||||
|
|
Loading…
Reference in New Issue