Refactor the traversal of operations to Convert in DialectConversion.

This cl changes the way that operations/blocks to convert are collected/traversed so that parent region operations can be legalized before their bodies. Most RewritePatterns for region operations assume that the entry arguments to each region are yet to be converted. Given that the bodies are currently converted first, this makes it difficult to fit these patterns into the same run as one converting types.

The operations/blocks to convert are now collected before any legalization has run, which simplifies the conversion logic itself, as legalization may insert new operations, move blocks, etc.

PiperOrigin-RevId: 258170158
This commit is contained in:
River Riddle 2019-07-15 08:40:11 -07:00 committed by Mehdi Amini
parent d2f1ed5137
commit 7d1e1e6721
1 changed files with 75 additions and 95 deletions

View File

@ -170,10 +170,13 @@ void ArgConverter::applyRewrites() {
continue;
}
// If mapping is from type to itself, replace the remaining uses and drop
// the cast operation.
if (op->getNumOperands() == 1 &&
op->getResult(0)->getType() == op->getOperand(0)->getType()) {
// If mapping is 1-1, replace the remaining uses and drop the cast
// operation.
// FIXME(riverriddle) This should check that the result type and operand
// type are the same, otherwise it should force a conversion to be
// materialized. This works around a current limitation with regards to
// region entry argument type conversion.
if (op->getNumOperands() == 1) {
op->getResult(0)->replaceAllUsesWith(op->getOperand(0));
op->destroy();
continue;
@ -229,7 +232,7 @@ void ArgConverter::convertSignature(
LogicalResult ArgConverter::convertArguments(Block *block,
BlockAndValueMapping &mapping) {
unsigned origArgCount = block->getNumArguments();
if (origArgCount == 0)
if (origArgCount == 0 || argMapping.count(block))
return success();
// Convert the types of each of the block arguments.
@ -440,10 +443,14 @@ struct DialectConversionRewriter final : public PatternRewriter {
for (auto &region : repl.op->getRegions())
argConverter.cancelPendingRewrites(region);
}
repl.op->erase();
}
// In a second pass, erase all of the replaced operations in reverse. This
// allows processing nested operations before their parent region is
// destroyed.
for (auto &repl : llvm::reverse(replacements))
repl.op->erase();
argConverter.applyRewrites();
}
@ -841,7 +848,7 @@ namespace {
// TypeConverter object is provided, then the types of block arguments will be
// converted using the appropriate 'convertType' calls.
struct FunctionConverter {
explicit FunctionConverter(MLIRContext *ctx, ConversionTarget &target,
explicit FunctionConverter(ConversionTarget &target,
OwningRewritePatternList &patterns,
TypeConverter *conversion = nullptr)
: typeConverter(conversion), opLegalizer(target, patterns) {}
@ -853,22 +860,18 @@ struct FunctionConverter {
convertFunction(FuncOp f,
TypeConverter::SignatureConversion *signatureConversion);
/// Converts the given region starting from the entry block and following the
/// block successors. Returns failure on error, success otherwise. Prints
/// error messages at `loc`.
LogicalResult convertRegion(DialectConversionRewriter &rewriter,
Region &region, bool convertEntryTypes = true);
private:
/// Converts a block or operation with the given rewriter.
LogicalResult convert(DialectConversionRewriter &rewriter,
llvm::PointerUnion<Operation *, Block *> &ptr);
/// Converts a block by traversing its operations sequentially, attempting to
/// match a pattern. If there is no match, recurses the operations regions if
/// it has any.
//
/// After converting operations, traverses the successor blocks unless they
/// have been visited already as indicated in `visitedBlocks`.
LogicalResult convertBlock(DialectConversionRewriter &rewriter, Block *block,
DenseSet<Block *> &visitedBlocks);
/// Recursively collect all of the blocks, and operations, to convert from
/// within 'region'.
LogicalResult computeConversionSet(
Region &region,
std::vector<llvm::PointerUnion<Operation *, Block *>> &toConvert);
/// Pointer to a specific dialect conversion info.
/// Pointer to the type converter.
TypeConverter *typeConverter;
/// The legalizer to use when converting operations.
@ -876,84 +879,56 @@ struct FunctionConverter {
};
} // end anonymous namespace
LogicalResult
FunctionConverter::convertBlock(DialectConversionRewriter &rewriter,
Block *block,
DenseSet<Block *> &visitedBlocks) {
// First, add the current block to the list of visited blocks.
visitedBlocks.insert(block);
if (block->empty())
/// Recursively collect all of the blocks to convert from within 'region'.
LogicalResult FunctionConverter::computeConversionSet(
Region &region,
std::vector<llvm::PointerUnion<Operation *, Block *>> &toConvert) {
if (region.empty())
return success();
// Preserve the successors before rewriting the operations.
SmallVector<Block *, 4> successors(block->getSuccessors());
// Traverse starting from the entry block.
SmallVector<Block *, 16> worklist(1, &region.front());
DenseSet<Block *> visitedBlocks;
visitedBlocks.insert(&region.front());
while (!worklist.empty()) {
auto *block = worklist.pop_back_val();
// Iterate over ops and convert them. Since the conversion may split the
// block, we eagerly take the pointer to the next operation in it. Splitting
// moves the operations from one block to another, so this will keep
// considering the original list of operations independently of the block
// within which they are currently located. This relies on iplist node API
// to get the next node in the list witout knowing which list it is, iterators
// are unsuitable because block splitting invalidates all iterators following
// the current one. Any operation inserted by the conversion, independently of
// its parent block, will be recursively legalized independently of this
// function.
Operation *current = &block->front();
Operation *next = nullptr;
do {
next = current->getNextNode();
// Traverse any held regions.
for (auto &region : current->getRegions())
if (!region.empty() && failed(convertRegion(rewriter, region)))
return failure();
// We only need to process blocks if we are changing argument types.
if (typeConverter)
toConvert.emplace_back(block);
// Legalize the current operation.
(void)opLegalizer.legalize(current, rewriter);
} while ((current = next));
// Compute the conversion set of each of the nested operations.
for (auto &op : *block) {
toConvert.emplace_back(&op);
for (auto &region : op.getRegions())
computeConversionSet(region, toConvert);
}
// Recurse to children that haven't been visited.
for (Block *succ : successors) {
if (visitedBlocks.count(succ))
continue;
if (failed(convertBlock(rewriter, succ, visitedBlocks)))
return failure();
for (Block *succ : block->getSuccessors())
if (visitedBlocks.insert(succ).second)
worklist.push_back(succ);
}
// Check that all blocks in the region were visited.
if (llvm::any_of(llvm::drop_begin(region.getBlocks(), 1),
[&](Block &block) { return !visitedBlocks.count(&block); }))
return emitError(region.getLoc(), "unreachable blocks were not converted");
return success();
}
/// Converts a block or operation with the given rewriter.
LogicalResult
FunctionConverter::convertRegion(DialectConversionRewriter &rewriter,
Region &region, bool convertEntryTypes) {
assert(!region.empty() && "expected non-empty region");
// Create the arguments of each of the blocks in the region. If a type
// converter was not provided, then we don't need to change any of the block
// types.
if (typeConverter) {
for (Block &block :
llvm::drop_begin(region.getBlocks(), convertEntryTypes ? 0 : 1)) {
if (failed(
rewriter.argConverter.convertArguments(&block, rewriter.mapping)))
return failure();
}
FunctionConverter::convert(DialectConversionRewriter &rewriter,
llvm::PointerUnion<Operation *, Block *> &ptr) {
// If this is a block, then convert the types of each of the arguments.
if (auto *block = ptr.dyn_cast<Block *>()) {
assert(typeConverter && "expected valid type converter");
return rewriter.argConverter.convertArguments(block, rewriter.mapping);
}
// Store the number of blocks before conversion (new blocks may be added due
// to splits or moves, but the operations in them will be processed
// elsewhere).
unsigned numBlocks = std::distance(region.begin(), region.end());
// Start a DFS-order traversal of the CFG to make sure defs are converted
// before uses in dominated blocks.
llvm::DenseSet<Block *> visitedBlocks;
if (failed(convertBlock(rewriter, &region.front(), visitedBlocks)))
return failure();
// If some blocks are not reachable through successor chains, they should have
// been removed by the DCE before this.
if (visitedBlocks.size() != numBlocks)
return emitError(region.getLoc(), "unreachable blocks were not converted");
// Otherwise, this is an operation to legalize.
(void)opLegalizer.legalize(ptr.get<Operation *>(), rewriter);
return success();
}
@ -971,13 +946,18 @@ LogicalResult FunctionConverter::convertFunction(
&f.getBody().front(), *signatureConversion, rewriter.mapping);
}
// Rewrite the function body.
if (failed(
convertRegion(rewriter, f.getBody(), /*convertEntryTypes=*/false))) {
// Reset any of the generated rewrites.
/// Compute the set of operations and blocks to convert.
std::vector<llvm::PointerUnion<Operation *, Block *>> toConvert;
if (failed(computeConversionSet(f.getBody(), toConvert)))
return failure();
// Convert each operation/block and discard rewrites on failure.
for (auto &it : toConvert) {
if (failed(convert(rewriter, it))) {
rewriter.discardRewrites();
return failure();
}
}
// Otherwise the body conversion succeeded, so apply all rewrites.
rewriter.applyRewrites();
@ -1146,7 +1126,7 @@ LogicalResult mlir::applyConversionPatterns(
// Build the function converter.
auto *ctx = fns.front().getContext();
FunctionConverter funcConverter(ctx, target, patterns, &converter);
FunctionConverter funcConverter(target, patterns, &converter);
// Try to convert each of the functions within the module.
SmallVector<NamedAttributeList, 4> argAttrs;
@ -1178,6 +1158,6 @@ LogicalResult
mlir::applyConversionPatterns(FuncOp fn, ConversionTarget &target,
OwningRewritePatternList &&patterns) {
// Convert the body of this function.
FunctionConverter converter(fn.getContext(), target, patterns);
FunctionConverter converter(target, patterns);
return converter.convertFunction(fn, /*signatureConversion=*/nullptr);
}