forked from OSchip/llvm-project
Update the OperationFolder to find a valid insertion point when materializing constants.
The OperationFolder currently just inserts into the entry block of a Function, but regions may be isolated above, i.e. explicit capture only, and blindly inserting constants may break the invariants of these regions. PiperOrigin-RevId: 254987796
This commit is contained in:
parent
2628641b23
commit
66ed7d6d83
|
@ -33,21 +33,8 @@ class Value;
|
||||||
|
|
||||||
/// A utility class for folding operations, and unifying duplicated constants
|
/// A utility class for folding operations, and unifying duplicated constants
|
||||||
/// generated along the way.
|
/// generated along the way.
|
||||||
///
|
|
||||||
/// To make sure constants properly dominate all their uses, constants are
|
|
||||||
/// moved to the beginning of the entry block of the function when tracked by
|
|
||||||
/// this class.
|
|
||||||
class OperationFolder {
|
class OperationFolder {
|
||||||
public:
|
public:
|
||||||
/// Constructs an instance for managing constants in the given function `f`.
|
|
||||||
/// Constants tracked by this instance will be moved to the entry block of
|
|
||||||
/// `f`. The insertion always happens at the very top of the entry block.
|
|
||||||
///
|
|
||||||
/// This instance does not proactively walk the operations inside `f`;
|
|
||||||
/// instead, users must invoke the following methods to manually handle each
|
|
||||||
/// operation of interest.
|
|
||||||
OperationFolder(Function *f) : function(f) {}
|
|
||||||
|
|
||||||
/// Tries to perform folding on the given `op`, including unifying
|
/// Tries to perform folding on the given `op`, including unifying
|
||||||
/// deduplicated constants. If successful, replaces `op`'s uses with
|
/// deduplicated constants. If successful, replaces `op`'s uses with
|
||||||
/// folded results, and returns success. `preReplaceAction` is invoked on `op`
|
/// folded results, and returns success. `preReplaceAction` is invoked on `op`
|
||||||
|
@ -67,7 +54,7 @@ public:
|
||||||
void notifyRemoval(Operation *op);
|
void notifyRemoval(Operation *op);
|
||||||
|
|
||||||
/// Create an operation of specific op type with the given builder,
|
/// Create an operation of specific op type with the given builder,
|
||||||
/// and immediately try to fold it. This functions populates 'results' with
|
/// and immediately try to fold it. This function populates 'results' with
|
||||||
/// the results after folding the operation.
|
/// the results after folding the operation.
|
||||||
template <typename OpTy, typename... Args>
|
template <typename OpTy, typename... Args>
|
||||||
void create(OpBuilder &builder, SmallVectorImpl<Value *> &results,
|
void create(OpBuilder &builder, SmallVectorImpl<Value *> &results,
|
||||||
|
@ -104,6 +91,13 @@ public:
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
/// This map keeps track of uniqued constants by dialect, attribute, and type.
|
||||||
|
/// A constant operation materializes an attribute with a type. Dialects may
|
||||||
|
/// generate different constants with the same input attribute and type, so we
|
||||||
|
/// also need to track per-dialect.
|
||||||
|
using ConstantMap =
|
||||||
|
DenseMap<std::tuple<Dialect *, Attribute, Type>, Operation *>;
|
||||||
|
|
||||||
/// Tries to perform folding on the given `op`. If successful, populates
|
/// Tries to perform folding on the given `op`. If successful, populates
|
||||||
/// `results` with the results of the folding.
|
/// `results` with the results of the folding.
|
||||||
LogicalResult tryToFold(Operation *op, SmallVectorImpl<Value *> &results,
|
LogicalResult tryToFold(Operation *op, SmallVectorImpl<Value *> &results,
|
||||||
|
@ -112,18 +106,13 @@ private:
|
||||||
|
|
||||||
/// Try to get or create a new constant entry. On success this returns the
|
/// Try to get or create a new constant entry. On success this returns the
|
||||||
/// constant operation, nullptr otherwise.
|
/// constant operation, nullptr otherwise.
|
||||||
Operation *tryGetOrCreateConstant(Dialect *dialect, OpBuilder &builder,
|
Operation *tryGetOrCreateConstant(ConstantMap &uniquedConstants,
|
||||||
|
Dialect *dialect, OpBuilder &builder,
|
||||||
Attribute value, Type type, Location loc);
|
Attribute value, Type type, Location loc);
|
||||||
|
|
||||||
/// The function where we are managing constant.
|
/// A mapping between an insertion region and the constants that have been
|
||||||
Function *function;
|
/// created within it.
|
||||||
|
DenseMap<Region *, ConstantMap> foldScopes;
|
||||||
/// This map keeps track of uniqued constants by dialect, attribute, and type.
|
|
||||||
/// A constant operation materializes an attribute with a type. Dialects may
|
|
||||||
/// generate different constants with the same input attribute and type, so we
|
|
||||||
/// also need to track per-dialect.
|
|
||||||
DenseMap<std::tuple<Dialect *, Attribute, Type>, Operation *>
|
|
||||||
uniquedConstants;
|
|
||||||
|
|
||||||
/// This map tracks all of the dialects that an operation is referenced by;
|
/// This map tracks all of the dialects that an operation is referenced by;
|
||||||
/// given that many dialects may generate the same constant.
|
/// given that many dialects may generate the same constant.
|
||||||
|
|
|
@ -1016,7 +1016,7 @@ void mlir::linalg::emitScalarImplementation(
|
||||||
ScopedContext scope(b, loc);
|
ScopedContext scope(b, loc);
|
||||||
auto *op = linalgOp.getOperation();
|
auto *op = linalgOp.getOperation();
|
||||||
if (auto copyOp = dyn_cast<CopyOp>(op)) {
|
if (auto copyOp = dyn_cast<CopyOp>(op)) {
|
||||||
OperationFolder state(op->getFunction());
|
OperationFolder state;
|
||||||
auto inputIvs = permuteIvs(parallelIvs, copyOp.inputPermutation(), state);
|
auto inputIvs = permuteIvs(parallelIvs, copyOp.inputPermutation(), state);
|
||||||
auto outputIvs = permuteIvs(parallelIvs, copyOp.outputPermutation(), state);
|
auto outputIvs = permuteIvs(parallelIvs, copyOp.outputPermutation(), state);
|
||||||
SmallVector<IndexHandle, 8> iivs(inputIvs.begin(), inputIvs.end());
|
SmallVector<IndexHandle, 8> iivs(inputIvs.begin(), inputIvs.end());
|
||||||
|
|
|
@ -210,7 +210,7 @@ static bool isStructurallyFusableProducer(LinalgOp producer, Value *readView,
|
||||||
}
|
}
|
||||||
|
|
||||||
static void fuseLinalgOps(Function &f, ArrayRef<int64_t> tileSizes) {
|
static void fuseLinalgOps(Function &f, ArrayRef<int64_t> tileSizes) {
|
||||||
OperationFolder state(&f);
|
OperationFolder state;
|
||||||
DenseSet<Operation *> eraseSet;
|
DenseSet<Operation *> eraseSet;
|
||||||
|
|
||||||
// 1. Record the linalg ops so we can traverse them in reverse order.
|
// 1. Record the linalg ops so we can traverse them in reverse order.
|
||||||
|
|
|
@ -105,7 +105,7 @@ struct LowerLinalgToLoopsPass : public FunctionPass<LowerLinalgToLoopsPass> {
|
||||||
|
|
||||||
void LowerLinalgToLoopsPass::runOnFunction() {
|
void LowerLinalgToLoopsPass::runOnFunction() {
|
||||||
auto &f = getFunction();
|
auto &f = getFunction();
|
||||||
OperationFolder state(&f);
|
OperationFolder state;
|
||||||
f.walk<LinalgOp>([&state](LinalgOp linalgOp) {
|
f.walk<LinalgOp>([&state](LinalgOp linalgOp) {
|
||||||
emitLinalgOpAsLoops(linalgOp, state);
|
emitLinalgOpAsLoops(linalgOp, state);
|
||||||
linalgOp.getOperation()->erase();
|
linalgOp.getOperation()->erase();
|
||||||
|
|
|
@ -260,7 +260,7 @@ mlir::linalg::tileLinalgOp(LinalgOp op, ArrayRef<int64_t> tileSizes,
|
||||||
}
|
}
|
||||||
|
|
||||||
static void tileLinalgOps(Function &f, ArrayRef<int64_t> tileSizes) {
|
static void tileLinalgOps(Function &f, ArrayRef<int64_t> tileSizes) {
|
||||||
OperationFolder state(&f);
|
OperationFolder state;
|
||||||
f.walk<LinalgOp>([tileSizes, &state](LinalgOp op) {
|
f.walk<LinalgOp>([tileSizes, &state](LinalgOp op) {
|
||||||
auto opLoopsPair = tileLinalgOp(op, tileSizes, state);
|
auto opLoopsPair = tileLinalgOp(op, tileSizes, state);
|
||||||
// If tiling occurred successfully, erase old op.
|
// If tiling occurred successfully, erase old op.
|
||||||
|
|
|
@ -29,6 +29,47 @@
|
||||||
|
|
||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
|
|
||||||
|
/// Given an operation, find the parent region that folded constants should be
|
||||||
|
/// inserted into.
|
||||||
|
static Region *getInsertionRegion(Operation *op) {
|
||||||
|
while (Region *region = op->getContainingRegion()) {
|
||||||
|
// Insert in this region for any of the following scenarios:
|
||||||
|
// * The parent is not an operation, i.e. is a Function.
|
||||||
|
// * The parent is unregistered, or is known to be isolated from above.
|
||||||
|
// * The parent is a top-level operation.
|
||||||
|
auto *parentOp = region->getContainingOp();
|
||||||
|
if (!parentOp || !parentOp->isRegistered() ||
|
||||||
|
parentOp->isKnownIsolatedFromAbove() || !parentOp->getBlock())
|
||||||
|
return region;
|
||||||
|
// Traverse up the parent looking for an insertion region.
|
||||||
|
op = parentOp;
|
||||||
|
}
|
||||||
|
llvm_unreachable("expected valid insertion region");
|
||||||
|
}
|
||||||
|
|
||||||
|
/// A utility function used to materialize a constant for a given attribute and
|
||||||
|
/// type. On success, a valid constant value is returned. Otherwise, null is
|
||||||
|
/// returned
|
||||||
|
static Operation *materializeConstant(Dialect *dialect, OpBuilder &builder,
|
||||||
|
Attribute value, Type type,
|
||||||
|
Location loc) {
|
||||||
|
auto insertPt = builder.getInsertionPoint();
|
||||||
|
(void)insertPt;
|
||||||
|
|
||||||
|
// Ask the dialect to materialize a constant operation for this value.
|
||||||
|
if (auto *constOp = dialect->materializeConstant(builder, value, type, loc)) {
|
||||||
|
assert(insertPt == builder.getInsertionPoint());
|
||||||
|
assert(matchPattern(constOp, m_Constant(&value)));
|
||||||
|
return constOp;
|
||||||
|
}
|
||||||
|
|
||||||
|
// If the dialect is unable to materialize a constant, check to see if the
|
||||||
|
// standard constant can be used.
|
||||||
|
if (ConstantOp::isBuildableWith(value, type))
|
||||||
|
return builder.create<ConstantOp>(loc, type, value);
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// OperationFolder
|
// OperationFolder
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
@ -37,9 +78,6 @@ LogicalResult OperationFolder::tryToFold(
|
||||||
Operation *op,
|
Operation *op,
|
||||||
llvm::function_ref<void(Operation *)> processGeneratedConstants,
|
llvm::function_ref<void(Operation *)> processGeneratedConstants,
|
||||||
llvm::function_ref<void(Operation *)> preReplaceAction) {
|
llvm::function_ref<void(Operation *)> preReplaceAction) {
|
||||||
assert(op->getFunction() == function &&
|
|
||||||
"cannot constant fold op from another function");
|
|
||||||
|
|
||||||
// If this is a unique'd constant, return failure as we know that it has
|
// If this is a unique'd constant, return failure as we know that it has
|
||||||
// already been folded.
|
// already been folded.
|
||||||
if (referencedDialects.count(op))
|
if (referencedDialects.count(op))
|
||||||
|
@ -70,9 +108,6 @@ LogicalResult OperationFolder::tryToFold(
|
||||||
/// Notifies that the given constant `op` should be remove from this
|
/// Notifies that the given constant `op` should be remove from this
|
||||||
/// OperationFolder's internal bookkeeping.
|
/// OperationFolder's internal bookkeeping.
|
||||||
void OperationFolder::notifyRemoval(Operation *op) {
|
void OperationFolder::notifyRemoval(Operation *op) {
|
||||||
assert(op->getFunction() == function &&
|
|
||||||
"cannot remove constant from another function");
|
|
||||||
|
|
||||||
// Check to see if this operation is uniqued within the folder.
|
// Check to see if this operation is uniqued within the folder.
|
||||||
auto it = referencedDialects.find(op);
|
auto it = referencedDialects.find(op);
|
||||||
if (it == referencedDialects.end())
|
if (it == referencedDialects.end())
|
||||||
|
@ -84,6 +119,9 @@ void OperationFolder::notifyRemoval(Operation *op) {
|
||||||
matchPattern(op, m_Constant(&constValue));
|
matchPattern(op, m_Constant(&constValue));
|
||||||
assert(constValue);
|
assert(constValue);
|
||||||
|
|
||||||
|
// Get the constant map that this operation was uniqued in.
|
||||||
|
auto &uniquedConstants = foldScopes[getInsertionRegion(op)];
|
||||||
|
|
||||||
// Erase all of the references to this operation.
|
// Erase all of the references to this operation.
|
||||||
auto type = op->getResult(0)->getType();
|
auto type = op->getResult(0)->getType();
|
||||||
for (auto *dialect : it->second)
|
for (auto *dialect : it->second)
|
||||||
|
@ -91,37 +129,11 @@ void OperationFolder::notifyRemoval(Operation *op) {
|
||||||
referencedDialects.erase(it);
|
referencedDialects.erase(it);
|
||||||
}
|
}
|
||||||
|
|
||||||
/// A utility function used to materialize a constant for a given attribute and
|
|
||||||
/// type. On success, a valid constant value is returned. Otherwise, null is
|
|
||||||
/// returned
|
|
||||||
static Operation *materializeConstant(Dialect *dialect, OpBuilder &builder,
|
|
||||||
Attribute value, Type type,
|
|
||||||
Location loc) {
|
|
||||||
auto insertPt = builder.getInsertionPoint();
|
|
||||||
(void)insertPt;
|
|
||||||
|
|
||||||
// Ask the dialect to materialize a constant operation for this value.
|
|
||||||
if (auto *constOp = dialect->materializeConstant(builder, value, type, loc)) {
|
|
||||||
assert(insertPt == builder.getInsertionPoint());
|
|
||||||
assert(matchPattern(constOp, m_Constant(&value)));
|
|
||||||
return constOp;
|
|
||||||
}
|
|
||||||
|
|
||||||
// If the dialect is unable to materialize a constant, check to see if the
|
|
||||||
// standard constant can be used.
|
|
||||||
if (ConstantOp::isBuildableWith(value, type))
|
|
||||||
return builder.create<ConstantOp>(loc, type, value);
|
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Tries to perform folding on the given `op`. If successful, populates
|
/// Tries to perform folding on the given `op`. If successful, populates
|
||||||
/// `results` with the results of the folding.
|
/// `results` with the results of the folding.
|
||||||
LogicalResult OperationFolder::tryToFold(
|
LogicalResult OperationFolder::tryToFold(
|
||||||
Operation *op, SmallVectorImpl<Value *> &results,
|
Operation *op, SmallVectorImpl<Value *> &results,
|
||||||
llvm::function_ref<void(Operation *)> processGeneratedConstants) {
|
llvm::function_ref<void(Operation *)> processGeneratedConstants) {
|
||||||
assert(op->getFunction() == function &&
|
|
||||||
"cannot constant fold op from another function");
|
|
||||||
|
|
||||||
SmallVector<Attribute, 8> operandConstants;
|
SmallVector<Attribute, 8> operandConstants;
|
||||||
SmallVector<OpFoldResult, 8> foldResults;
|
SmallVector<OpFoldResult, 8> foldResults;
|
||||||
|
|
||||||
|
@ -148,9 +160,14 @@ LogicalResult OperationFolder::tryToFold(
|
||||||
return success();
|
return success();
|
||||||
assert(foldResults.size() == op->getNumResults());
|
assert(foldResults.size() == op->getNumResults());
|
||||||
|
|
||||||
// Create a builder to insert new operations into the entry block.
|
// Create a builder to insert new operations into the entry block of the
|
||||||
auto &entry = function->getBody().front();
|
// insertion region.
|
||||||
OpBuilder builder(&entry, entry.empty() ? entry.end() : entry.begin());
|
auto *insertionRegion = getInsertionRegion(op);
|
||||||
|
auto &entry = insertionRegion->front();
|
||||||
|
OpBuilder builder(&entry, entry.begin());
|
||||||
|
|
||||||
|
// Get the constant map for the insertion region of this operation.
|
||||||
|
auto &uniquedConstants = foldScopes[insertionRegion];
|
||||||
|
|
||||||
// Create the result constants and replace the results.
|
// Create the result constants and replace the results.
|
||||||
auto *dialect = op->getDialect();
|
auto *dialect = op->getDialect();
|
||||||
|
@ -166,8 +183,9 @@ LogicalResult OperationFolder::tryToFold(
|
||||||
// Check to see if there is a canonicalized version of this constant.
|
// Check to see if there is a canonicalized version of this constant.
|
||||||
auto *res = op->getResult(i);
|
auto *res = op->getResult(i);
|
||||||
Attribute attrRepl = foldResults[i].get<Attribute>();
|
Attribute attrRepl = foldResults[i].get<Attribute>();
|
||||||
if (auto *constOp = tryGetOrCreateConstant(dialect, builder, attrRepl,
|
if (auto *constOp =
|
||||||
res->getType(), op->getLoc())) {
|
tryGetOrCreateConstant(uniquedConstants, dialect, builder, attrRepl,
|
||||||
|
res->getType(), op->getLoc())) {
|
||||||
results.push_back(constOp->getResult(0));
|
results.push_back(constOp->getResult(0));
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
@ -192,10 +210,9 @@ LogicalResult OperationFolder::tryToFold(
|
||||||
|
|
||||||
/// Try to get or create a new constant entry. On success this returns the
|
/// Try to get or create a new constant entry. On success this returns the
|
||||||
/// constant operation value, nullptr otherwise.
|
/// constant operation value, nullptr otherwise.
|
||||||
Operation *OperationFolder::tryGetOrCreateConstant(Dialect *dialect,
|
Operation *OperationFolder::tryGetOrCreateConstant(
|
||||||
OpBuilder &builder,
|
ConstantMap &uniquedConstants, Dialect *dialect, OpBuilder &builder,
|
||||||
Attribute value, Type type,
|
Attribute value, Type type, Location loc) {
|
||||||
Location loc) {
|
|
||||||
// Check if an existing mapping already exists.
|
// Check if an existing mapping already exists.
|
||||||
auto constKey = std::make_tuple(dialect, value, type);
|
auto constKey = std::make_tuple(dialect, value, type);
|
||||||
auto *&constInst = uniquedConstants[constKey];
|
auto *&constInst = uniquedConstants[constKey];
|
||||||
|
|
|
@ -143,9 +143,7 @@ private:
|
||||||
/// Perform the rewrites.
|
/// Perform the rewrites.
|
||||||
bool GreedyPatternRewriteDriver::simplifyFunction(int maxIterations) {
|
bool GreedyPatternRewriteDriver::simplifyFunction(int maxIterations) {
|
||||||
Region *region = getRegion();
|
Region *region = getRegion();
|
||||||
|
OperationFolder helper;
|
||||||
// TODO(riverriddle) OperationFolder should take a region to insert into.
|
|
||||||
OperationFolder helper(region->getContainingFunction());
|
|
||||||
|
|
||||||
// Add the given operation to the worklist.
|
// Add the given operation to the worklist.
|
||||||
auto collectOps = [this](Operation *op) { addToWorklist(op); };
|
auto collectOps = [this](Operation *op) { addToWorklist(op); };
|
||||||
|
|
|
@ -421,6 +421,7 @@ func @fold_extract_element(%arg0 : index) -> (f32, f16, f16, i32) {
|
||||||
return %ext_1, %ext_2, %ext_3, %ext_4 : f32, f16, f16, i32
|
return %ext_1, %ext_2, %ext_3, %ext_4 : f32, f16, f16, i32
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
// CHECK-LABEL: func @fold_rank
|
// CHECK-LABEL: func @fold_rank
|
||||||
func @fold_rank() -> (index) {
|
func @fold_rank() -> (index) {
|
||||||
|
@ -434,3 +435,24 @@ func @fold_rank() -> (index) {
|
||||||
return %rank_0 : index
|
return %rank_0 : index
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: func @nested_isolated_region
|
||||||
|
func @nested_isolated_region() {
|
||||||
|
// CHECK-NEXT: func @isolated_op
|
||||||
|
// CHECK-NEXT: constant 2
|
||||||
|
func @isolated_op() {
|
||||||
|
%0 = constant 1 : i32
|
||||||
|
%2 = addi %0, %0 : i32
|
||||||
|
"foo.yield"(%2) : (i32) -> ()
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK: "foo.unknown_region"
|
||||||
|
// CHECK-NEXT: constant 2
|
||||||
|
"foo.unknown_region"() ({
|
||||||
|
%0 = constant 1 : i32
|
||||||
|
%2 = addi %0, %0 : i32
|
||||||
|
"foo.yield"(%2) : (i32) -> ()
|
||||||
|
}) : () -> ()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
|
@ -53,17 +53,15 @@ void TestConstantFold::foldOperation(Operation *op, OperationFolder &helper) {
|
||||||
void TestConstantFold::runOnFunction() {
|
void TestConstantFold::runOnFunction() {
|
||||||
existingConstants.clear();
|
existingConstants.clear();
|
||||||
|
|
||||||
auto &f = getFunction();
|
|
||||||
OperationFolder helper(&f);
|
|
||||||
|
|
||||||
// Collect and fold the operations within the function.
|
// Collect and fold the operations within the function.
|
||||||
SmallVector<Operation *, 8> ops;
|
SmallVector<Operation *, 8> ops;
|
||||||
f.walk([&](Operation *op) { ops.push_back(op); });
|
getFunction().walk([&](Operation *op) { ops.push_back(op); });
|
||||||
|
|
||||||
// Fold the constants in reverse so that the last generated constants from
|
// Fold the constants in reverse so that the last generated constants from
|
||||||
// folding are at the beginning. This creates somewhat of a linear ordering to
|
// folding are at the beginning. This creates somewhat of a linear ordering to
|
||||||
// the newly generated constants that matches the operation order and improves
|
// the newly generated constants that matches the operation order and improves
|
||||||
// the readability of test cases.
|
// the readability of test cases.
|
||||||
|
OperationFolder helper;
|
||||||
for (Operation *op : llvm::reverse(ops))
|
for (Operation *op : llvm::reverse(ops))
|
||||||
foldOperation(op, helper);
|
foldOperation(op, helper);
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue